From f04f81080727febc88304b8edc17b2810806b321 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 11 Mar 2025 16:28:29 +0800 Subject: [PATCH 001/224] Normalize the sampling-based ltr with num pairs instead of grad. (#11322) * Normalize the sampling-based ltr with num pairs instead of grad. * lint. * Cleanup. * Define a new method. * rt param. * Revert "rt param." This reverts commit d6192083f904f6327a2baeafc561d5862cf9348d. --- doc/parameter.rst | 4 ++ doc/tutorials/learning_to_rank.rst | 2 - include/xgboost/base.h | 8 ++- python-package/xgboost/testing/ranking.py | 48 +++++++++++++++++ src/common/ranking_utils.cuh | 2 + src/common/ranking_utils.h | 13 +++-- src/objective/lambdarank_obj.cc | 21 ++++++-- src/objective/lambdarank_obj.cu | 61 ++++++++++++++-------- src/objective/lambdarank_obj.cuh | 6 ++- src/objective/lambdarank_obj.h | 8 +-- tests/cpp/objective/test_lambdarank_obj.cc | 39 +++++++------- tests/cpp/objective/test_lambdarank_obj.cu | 1 + tests/cpp/objective/test_lambdarank_obj.h | 5 +- 13 files changed, 154 insertions(+), 64 deletions(-) diff --git a/doc/parameter.rst b/doc/parameter.rst index e9a309c24766..2eedf39fe9de 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -540,6 +540,10 @@ These are parameters specific to learning to rank task. See :doc:`Learning to Ra Whether to normalize the leaf value by lambda gradient. This can sometimes stagnate the training progress. + .. versionchanged:: 3.1.0 + + When the ``mean`` method is used, it's normalized by the ``lambdarank_num_pair_per_sample`` instead of gradient. + * ``lambdarank_score_normalization`` [default = ``true``] .. versionadded:: 3.0.0 diff --git a/doc/tutorials/learning_to_rank.rst b/doc/tutorials/learning_to_rank.rst index e1c1ab85a3eb..ea5309d31ca0 100644 --- a/doc/tutorials/learning_to_rank.rst +++ b/doc/tutorials/learning_to_rank.rst @@ -198,8 +198,6 @@ The learning to rank implementation has been significantly updated in 2.0 with a # 1.7 only supports sampling, while 2.0 and later use top-k as the default. # See above sections for the trade-off. "lambdarank_pair_method": "mean", - # Normalization was added in 2.0 - "lambdarank_normalization": False, # 1.7 uses the ranknet loss while later versions use the NDCG weighted loss "objective": "rank:pairwise", # 1.7 doesn't have this normalization. diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 64aab5c41b0c..4318bd808631 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -105,9 +105,13 @@ using bst_bin_t = std::int32_t; // NOLINT * @brief Type for data row index (sample). */ using bst_idx_t = std::uint64_t; // NOLINT -/*! \brief Type for tree node index. */ +/** + * \brief Type for tree node index. + */ using bst_node_t = std::int32_t; // NOLINT -/*! \brief Type for ranking group index. */ +/** + * @brief Type for ranking group index. + */ using bst_group_t = std::uint32_t; // NOLINT /** * @brief Type for indexing into output targets. 
diff --git a/python-package/xgboost/testing/ranking.py b/python-package/xgboost/testing/ranking.py index ebf88eceecf2..588c210750c8 100644 --- a/python-package/xgboost/testing/ranking.py +++ b/python-package/xgboost/testing/ranking.py @@ -105,6 +105,7 @@ def run_ranking_categorical(device: str) -> None: def run_normalization(device: str) -> None: """Test normalization.""" X, y, qid, _ = tm.make_ltr(2048, 4, 64, 3) + # top-k ltr = xgb.XGBRanker(objective="rank:pairwise", n_estimators=4, device=device) ltr.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid]) e0 = ltr.evals_result() @@ -119,6 +120,53 @@ def run_normalization(device: str) -> None: e1 = ltr.evals_result() assert e1["validation_0"]["ndcg@32"][-1] > e0["validation_0"]["ndcg@32"][-1] + # mean + ltr = xgb.XGBRanker( + objective="rank:pairwise", + n_estimators=4, + device=device, + lambdarank_pair_method="mean", + lambdarank_normalization=True, + ) + ltr.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid]) + e0 = ltr.evals_result() + + ltr = xgb.XGBRanker( + objective="rank:pairwise", + n_estimators=4, + device=device, + lambdarank_pair_method="mean", + lambdarank_normalization=False, + ) + ltr.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid]) + e1 = ltr.evals_result() + # no normalization since the number of pairs is 1. + assert e1["validation_0"]["ndcg"][-1] == e0["validation_0"]["ndcg"][-1] + + # mean + ltr = xgb.XGBRanker( + objective="rank:pairwise", + n_estimators=4, + device=device, + lambdarank_pair_method="mean", + lambdarank_normalization=True, + lambdarank_num_pair_per_sample=4, + ) + ltr.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid]) + e0 = ltr.evals_result() + + ltr = xgb.XGBRanker( + objective="rank:pairwise", + n_estimators=4, + device=device, + lambdarank_pair_method="mean", + lambdarank_normalization=False, + lambdarank_num_pair_per_sample=4, + ) + ltr.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid]) + e1 = ltr.evals_result() + assert e1["validation_0"]["ndcg"][-1] != e0["validation_0"]["ndcg"][-1] + def run_score_normalization(device: str, objective: str) -> None: """Test normalization by score differences.""" diff --git a/src/common/ranking_utils.cuh b/src/common/ranking_utils.cuh index 297f5157ecfb..9025dfdbc533 100644 --- a/src/common/ranking_utils.cuh +++ b/src/common/ranking_utils.cuh @@ -30,6 +30,8 @@ XGBOOST_DEVICE __forceinline__ std::size_t ThreadsForMean(std::size_t group_size std::size_t n_pairs) { return group_size * n_pairs; } +// Number of threads in a group divided by the number of samples in this group, returns +// the number of pairs for pair-wise ltr with sampling. 
XGBOOST_DEVICE __forceinline__ std::size_t PairsForGroup(std::size_t n_threads, std::size_t group_size) { return n_threads / group_size; diff --git a/src/common/ranking_utils.h b/src/common/ranking_utils.h index 8d98dfb913d7..16a264fdc967 100644 --- a/src/common/ranking_utils.h +++ b/src/common/ranking_utils.h @@ -115,6 +115,7 @@ struct LambdaRankParam : public XGBoostParameter { } [[nodiscard]] bool HasTruncation() const { return lambdarank_pair_method == PairMethod::kTopK; } + [[nodiscard]] bool IsMean() const { return lambdarank_pair_method == PairMethod::kMean; } // Used for evaluation metric and cache initialization, iterate through top-k or the whole list [[nodiscard]] auto TopK() const { @@ -180,7 +181,8 @@ class RankingCache { HostDeviceVector y_sorted_idx_cache_; // Cached labels sorted by the model HostDeviceVector y_ranked_by_model_; - // store rounding factor for objective for each group + // Rounding factor for CUDA deterministic floating point summation. One rounding factor + // for each ranking group. linalg::Vector roundings_; // rounding factor for cost HostDeviceVector cost_rounding_; @@ -215,6 +217,9 @@ class RankingCache { if (!info.weights_.Empty()) { CHECK_EQ(Groups(), info.weights_.Size()) << error::GroupWeight(); } + if (param_.HasTruncation()) { + CHECK_GE(param_.NumPair(), 1); + } } [[nodiscard]] std::size_t MaxPositionSize() const { // Use truncation level as bound. @@ -267,21 +272,21 @@ class RankingCache { } // CUDA cache getters, the cache is shared between metric and objective, some of these - // fields are lazy initialized to avoid unnecessary allocation. + // fields are initialized lazily to avoid unnecessary allocation. [[nodiscard]] common::Span CUDAThreadsGroupPtr() const { CHECK(!threads_group_ptr_.Empty()); return threads_group_ptr_.ConstDeviceSpan(); } [[nodiscard]] std::size_t CUDAThreads() const { return n_cuda_threads_; } - linalg::VectorView CUDARounding(Context const* ctx) { + [[nodiscard]] linalg::VectorView CUDARounding(Context const* ctx) { if (roundings_.Size() == 0) { roundings_.SetDevice(ctx->Device()); roundings_.Reshape(Groups()); } return roundings_.View(ctx->Device()); } - common::Span CUDACostRounding(Context const* ctx) { + [[nodiscard]] common::Span CUDACostRounding(Context const* ctx) { if (cost_rounding_.Size() == 0) { cost_rounding_.SetDevice(ctx->Device()); cost_rounding_.Resize(1); diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc index b19f72e1d46f..45ea357425b0 100644 --- a/src/objective/lambdarank_obj.cc +++ b/src/objective/lambdarank_obj.cc @@ -225,10 +225,23 @@ class LambdaRankObj : public FitIntercept { }; MakePairs(ctx_, iter, p_cache_, g, g_label, g_rank, loop); - if (sum_lambda > 0.0 && param_.lambdarank_normalization) { - double norm = std::log2(1.0 + sum_lambda) / sum_lambda; - std::transform(g_gpair.Values().data(), g_gpair.Values().data() + g_gpair.Size(), - g_gpair.Values().data(), [norm](GradientPair const& g) { return g * norm; }); + if (param_.lambdarank_normalization) { + double norm = 1.0; + if (param_.IsMean()) { + // Normalize using the number of pairs for mean. + auto n_pairs = this->p_cache_->Param().NumPair(); + auto scale = 1.0 / static_cast(n_pairs); + norm = scale; + } else { + // Normalize using gradient for top-k. 
+ if (sum_lambda > 0.0) { + norm = std::log2(1.0 + sum_lambda) / sum_lambda; + } + } + if (norm != 1.0) { + std::transform(linalg::begin(g_gpair), linalg::end(g_gpair), linalg::begin(g_gpair), + [norm](GradientPair const& g) { return g * norm; }); + } } auto w_norm = p_cache_->WeightNorm(); diff --git a/src/objective/lambdarank_obj.cu b/src/objective/lambdarank_obj.cu index eae067a56649..8e4dc8c36252 100644 --- a/src/objective/lambdarank_obj.cu +++ b/src/objective/lambdarank_obj.cu @@ -4,19 +4,18 @@ * \brief CUDA implementation of lambdarank. */ #include // for DMLC_REGISTRY_FILE_TAG - #include // for fill_n #include // for for_each_n #include // for make_counting_iterator #include // for make_zip_iterator #include // for make_tuple, tuple, tie, get -#include // for min -#include // for assert -#include // for abs, log2, isinf -#include // for size_t -#include // for int32_t -#include // for shared_ptr +#include // for min +#include // for assert +#include // for abs, log2, isinf +#include // for size_t +#include // for int32_t +#include // for shared_ptr #include #include "../common/algorithm.cuh" // for SegmentedArgSort @@ -33,7 +32,7 @@ #include "xgboost/host_device_vector.h" // for HostDeviceVector #include "xgboost/linalg.h" // for VectorView, Range, Vector #include "xgboost/logging.h" -#include "xgboost/span.h" // for Span +#include "xgboost/span.h" // for Span namespace xgboost::obj { DMLC_REGISTRY_FILE_TAG(lambdarank_obj_cu); @@ -84,7 +83,7 @@ struct GetGradOp { MakePairsOp make_pair; Delta delta; - bool need_update; + bool const need_update; auto __device__ operator()(std::size_t idx) -> GradCostNorm { auto const& args = make_pair.args; @@ -97,6 +96,7 @@ struct GetGradOp { auto g_predt = args.predts.subspan(data_group_begin, n_data); auto g_gpair = args.gpairs.Slice(linalg::Range(data_group_begin, data_group_begin + n_data)); auto g_rank = args.d_sorted_idx.subspan(data_group_begin, n_data); + auto n_pairs = args.n_pairs; auto [i, j] = make_pair(idx, g); @@ -110,7 +110,9 @@ struct GetGradOp { double cost{0}; - auto delta_op = [&](auto const&... args) { return delta(args..., g); }; + auto delta_op = [&](auto const&... args) { + return delta(args..., g); + }; GradientPair pg = LambdaGrad(g_label, g_predt, g_rank, rank_high, rank_low, delta_op, args.ti_plus, args.tj_minus, &cost); @@ -120,7 +122,6 @@ struct GetGradOp { if (need_update) { // second run, update the gradient - auto ng = Repulse(pg); auto gr = args.d_roundings(g); @@ -155,6 +156,7 @@ struct GetGradOp { } } } + return thrust::make_tuple(GradientPair{std::abs(pg.GetGrad()), std::abs(pg.GetHess())}, std::abs(cost), -2.0 * static_cast(pg.GetGrad())); } @@ -217,12 +219,12 @@ void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptr(l), thrust::get<1>(r)); double sum_lambda = thrust::get<2>(l) + thrust::get<2>(r); - return thrust::make_tuple(GradientPair{std::abs(grad), std::abs(hess)}, cost, sum_lambda); + return thrust::make_tuple(GradientPair{grad, hess}, cost, sum_lambda); }; auto init = thrust::make_tuple(GradientPair{0.0f, 0.0f}, 0.0, 0.0); common::Span d_max_lambdas = p_cache->MaxLambdas(ctx, n_groups); CHECK_EQ(n_groups * sizeof(GradCostNorm), d_max_lambdas.size_bytes()); - + // Reduce by group. 
std::size_t bytes; cub::DeviceSegmentedReduce::Reduce(nullptr, bytes, val_it, d_max_lambdas.data(), n_groups, d_threads_group_ptr.data(), d_threads_group_ptr.data() + 1, @@ -269,22 +271,35 @@ void CalcGrad(Context const* ctx, MetaInfo const& info, std::shared_ptrWeightNorm(); - auto norm = p_cache->Param().lambdarank_normalization; + auto need_norm = p_cache->Param().lambdarank_normalization; + auto n_pairs = p_cache->Param().NumPair(); + bool is_mean = p_cache->Param().IsMean(); + CHECK_EQ(is_mean, !has_truncation); thrust::for_each_n(ctx->CUDACtx()->CTP(), thrust::make_counting_iterator(0ul), d_gpair.Size(), [=] XGBOOST_DEVICE(std::size_t i) mutable { auto g = dh::SegmentId(d_gptr, i); - auto sum_lambda = thrust::get<2>(d_max_lambdas[g]); - // Normalization - if (sum_lambda > 0.0 && norm) { - double norm = std::log2(1.0 + sum_lambda) / sum_lambda; + if (need_norm) { + double norm = 1.0; + if (has_truncation) { + // Normalize using gradient for top-k. + auto sum_lambda = thrust::get<2>(d_max_lambdas[g]); + if (sum_lambda > 0.0) { + norm = std::log2(1.0 + sum_lambda) / sum_lambda; + } + } else { + // Normalize using the number of pairs for mean. + double scale = 1.0 / static_cast(n_pairs); + norm = scale; + } d_gpair(i, 0) *= norm; } + d_gpair(i, 0) *= (d_weights[g] * w_norm); }); } /** - * \brief Handles boilerplate code like getting device span. + * @brief Handles boilerplate code like getting device spans. */ template void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector const& preds, @@ -304,7 +319,6 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector const out_gpair->Reshape(preds.Size(), 1); CHECK(p_cache); - auto d_rounding = p_cache->CUDARounding(ctx); auto d_cost_rounding = p_cache->CUDACostRounding(ctx); @@ -327,9 +341,10 @@ void Launch(Context const* ctx, std::int32_t iter, HostDeviceVector const d_y_sorted_idx = SortY(ctx, info, rank_idx, p_cache); } - KernelInputs args{ti_plus, tj_minus, li, lj, d_gptr, d_threads_group_ptr, - rank_idx, label, predts, gpairs, d_rounding, d_cost_rounding.data(), - d_y_sorted_idx, iter}; + auto n_pairs = p_cache->Param().NumPair(); + KernelInputs args{ti_plus, tj_minus, li, lj, d_gptr, d_threads_group_ptr, + rank_idx, label, predts, gpairs, d_rounding, d_cost_rounding.data(), + n_pairs, d_y_sorted_idx, iter}; // dispatch based on unbiased and truncation if (p_cache->Param().HasTruncation()) { diff --git a/src/objective/lambdarank_obj.cuh b/src/objective/lambdarank_obj.cuh index e1a78f905434..ce95304197da 100644 --- a/src/objective/lambdarank_obj.cuh +++ b/src/objective/lambdarank_obj.cuh @@ -66,6 +66,7 @@ struct KernelInputs { linalg::VectorView d_roundings; double const *d_cost_rounding; + ltr::position_t const n_pairs; common::Span d_y_sorted_idx; std::int32_t iter; @@ -136,9 +137,10 @@ struct MakePairsOp { // The index pointing to the first element of the next bucket std::size_t right_bound = n_data - n_rights; - thrust::minstd_rand rng(args.iter); + std::uint32_t seed = args.iter * (static_cast(args.d_group_ptr.size()) - 1) + g; + thrust::minstd_rand rng(seed); auto pair_idx = i; - rng.discard(sample_pair_idx * n_data + g + pair_idx); // fixme + rng.discard(idx - args.d_threads_group_ptr[g]); // idx within group thrust::uniform_int_distribution dist(0, n_lefts + n_rights - 1); auto ridx = dist(rng); SPAN_CHECK(ridx < n_lefts + n_rights); diff --git a/src/objective/lambdarank_obj.h b/src/objective/lambdarank_obj.h index 113fce832492..56e57582eece 100644 --- a/src/objective/lambdarank_obj.h +++ 
b/src/objective/lambdarank_obj.h @@ -227,15 +227,16 @@ void MakePairs(Context const* ctx, std::int32_t iter, ltr::position_t cnt = group_ptr[g + 1] - group_ptr[g]; if (cache->Param().HasTruncation()) { - for (std::size_t i = 0; i < std::min(cnt, cache->Param().NumPair()); ++i) { + for (std::size_t i = 0, n = std::min(cnt, cache->Param().NumPair()); i < n; ++i) { for (std::size_t j = i + 1; j < cnt; ++j) { op(i, j); } } } else { CHECK_EQ(g_rank.size(), g_label.Size()); - std::minstd_rand rnd(iter); - rnd.discard(g); // fixme(jiamingy): honor the global seed + + std::uint32_t seed = iter * (static_cast(group_ptr.size()) - 1) + g; + std::minstd_rand rnd(seed); // sort label according to the rank list auto it = common::MakeIndexTransformIter( [&g_rank, &g_label](std::size_t idx) { return g_label(g_rank[idx]); }); @@ -244,7 +245,6 @@ void MakePairs(Context const* ctx, std::int32_t iter, // permutation iterator to get the original label auto rev_it = common::MakeIndexTransformIter( [&](std::size_t idx) { return g_label(g_rank[y_sorted_idx[idx]]); }); - for (std::size_t i = 0; i < cnt;) { std::size_t j = i + 1; // find the bucket boundary diff --git a/tests/cpp/objective/test_lambdarank_obj.cc b/tests/cpp/objective/test_lambdarank_obj.cc index db8472a2a7dd..7d1639e4fb79 100644 --- a/tests/cpp/objective/test_lambdarank_obj.cc +++ b/tests/cpp/objective/test_lambdarank_obj.cc @@ -3,25 +3,26 @@ */ #include "test_lambdarank_obj.h" -#include // for Test, Message, TestPartResult, CmpHel... - -#include // for sort -#include // for size_t -#include // for initializer_list -#include // for unique_ptr, shared_ptr, make_shared -#include // for iota -#include // for char_traits, basic_string, string -#include // for vector - -#include "../../../src/common/ranking_utils.h" // for NDCGCache, LambdaRankParam -#include "../helpers.h" // for CheckRankingObjFunction, CheckConfigReload -#include "xgboost/base.h" // for GradientPair, bst_group_t, Args -#include "xgboost/context.h" // for Context -#include "xgboost/data.h" // for MetaInfo, DMatrix -#include "xgboost/host_device_vector.h" // for HostDeviceVector -#include "xgboost/linalg.h" // for Tensor, All, TensorView -#include "xgboost/objective.h" // for ObjFunction -#include "xgboost/span.h" // for Span +#include // for Test, Message, TestPartResult, CmpHel... 
+ +#include // for sort +#include // for size_t +#include // for initializer_list +#include // for unique_ptr, shared_ptr, make_shared +#include // for iota +#include // for char_traits, basic_string, string +#include // for vector + +#include "../../../src/common/ranking_utils.h" // for NDCGCache, LambdaRankParam +#include "../../../src/objective/lambdarank_obj.h" // for MAPStat, MakePairs +#include "../helpers.h" // for CheckRankingObjFunction, CheckConfigReload +#include "xgboost/base.h" // for GradientPair, bst_group_t, Args +#include "xgboost/context.h" // for Context +#include "xgboost/data.h" // for MetaInfo, DMatrix +#include "xgboost/host_device_vector.h" // for HostDeviceVector +#include "xgboost/linalg.h" // for Tensor, All, TensorView +#include "xgboost/objective.h" // for ObjFunction +#include "xgboost/span.h" // for Span namespace xgboost::obj { TEST(LambdaRank, NDCGJsonIO) { diff --git a/tests/cpp/objective/test_lambdarank_obj.cu b/tests/cpp/objective/test_lambdarank_obj.cu index c80ec20fc63d..d33273678662 100644 --- a/tests/cpp/objective/test_lambdarank_obj.cu +++ b/tests/cpp/objective/test_lambdarank_obj.cu @@ -55,6 +55,7 @@ void TestGPUMakePair() { linalg::MatrixView{common::Span{}, {0}, DeviceOrd::CUDA(0)}, dg, nullptr, + 1, y_sorted_idx, 0}; return args; diff --git a/tests/cpp/objective/test_lambdarank_obj.h b/tests/cpp/objective/test_lambdarank_obj.h index 9539f1a3003e..4383a44d1a75 100644 --- a/tests/cpp/objective/test_lambdarank_obj.h +++ b/tests/cpp/objective/test_lambdarank_obj.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2023, XGBoost Contributors + * Copyright 2023-2025, XGBoost Contributors */ #ifndef XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ #define XGBOOST_OBJECTIVE_TEST_LAMBDARANK_OBJ_H_ @@ -10,11 +10,8 @@ #include // for ObjFunction #include // for shared_ptr, make_shared -#include // for iota -#include // for vector #include "../../../src/common/ranking_utils.h" // for LambdaRankParam, MAPCache -#include "../../../src/objective/lambdarank_obj.h" // for MAPStat #include "../helpers.h" // for EmptyDMatrix namespace xgboost::obj { From 58e908c6649a166c3b47ea5e06e2a8539718a2b0 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 11 Mar 2025 19:08:26 +0800 Subject: [PATCH 002/224] [doc] Fix version change. (#11328) --- doc/parameter.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/parameter.rst b/doc/parameter.rst index 2eedf39fe9de..0125dbdae9d1 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -540,7 +540,7 @@ These are parameters specific to learning to rank task. See :doc:`Learning to Ra Whether to normalize the leaf value by lambda gradient. This can sometimes stagnate the training progress. - .. versionchanged:: 3.1.0 + .. versionchanged:: 3.0.0 When the ``mean`` method is used, it's normalized by the ``lambdarank_num_pair_per_sample`` instead of gradient. From a57657bb365e2d2e081b8ec77297d2f889e21517 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Wed, 12 Mar 2025 11:08:59 +0100 Subject: [PATCH 003/224] [sycl] fix init estimations (#11331) --- src/tree/fit_stump.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tree/fit_stump.cc b/src/tree/fit_stump.cc index 8fdcb3131646..144abcbd8131 100644 --- a/src/tree/fit_stump.cc +++ b/src/tree/fit_stump.cc @@ -74,7 +74,7 @@ void FitStump(Context const* ctx, MetaInfo const& info, linalg::MatrixReshape(n_targets); gpair.SetDevice(ctx->Device()); - auto gpair_t = gpair.View(ctx->Device()); + auto gpair_t = gpair.View(ctx->Device().IsSycl() ? 
DeviceOrd::CPU() : ctx->Device()); ctx->IsCUDA() ? cuda_impl::FitStump(ctx, info, gpair_t, out->View(ctx->Device())) : cpu_impl::FitStump(ctx, info, gpair_t, out->HostView()); } From 5927d5da838ae9e7f5dd0e9e7a9a88dd4d8199e4 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 13 Mar 2025 06:35:13 +0800 Subject: [PATCH 004/224] [EM] Disable the `on_host` option for CPU inputs. (#11333) --- src/data/sparse_page_dmatrix.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index f3a26a391d9a..be726c80b48b 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -160,11 +160,11 @@ BatchSet SparsePageDMatrix::GetGradientIndex(Context const *ct CHECK_GE(param.max_bin, 2); } detail::CheckEmpty(batch_param_, param); - auto id = MakeCache(this, ".gradient_index.page", on_host_, cache_prefix_, &cache_info_); + auto id = MakeCache(this, ".gradient_index.page", false, cache_prefix_, &cache_info_); if (!cache_info_.at(id)->written || detail::RegenGHist(batch_param_, param)) { this->InitializeSparsePage(ctx); cache_info_.erase(id); - id = MakeCache(this, ".gradient_index.page", on_host_, cache_prefix_, &cache_info_); + id = MakeCache(this, ".gradient_index.page", false, cache_prefix_, &cache_info_); LOG(INFO) << "Generating new Gradient Index."; // Use sorted sketch for approx. auto sorted_sketch = param.regen; From d63d98b2ebfc0f847a4c1b1f40a02772fe065216 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 14 Mar 2025 14:07:08 +0800 Subject: [PATCH 005/224] Auto re-coding for the CPU predictor. (#11315) - inplace predict - DMatrix - Raise an error when input categories are floating points. This PR integrates the re-coder into the CPU predictor by defining an accessor for calculating the mapped values on the fly. For a numeric-only dataset, it's a no op. 
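
For illustration, a minimal sketch of the recoding behaviour during prediction
(it mirrors the run_cat_predict test added below; the pandas DataFrame input and
CPU device are assumptions for the example, not requirements of the change):

    import numpy as np
    import pandas as pd
    import xgboost as xgb

    # Train with a categorical column; the category -> code mapping is kept in the model.
    df = pd.DataFrame({"c": ["cdef", "abc", "def"]}, dtype="category")
    y = np.array([0.0, 1.0, 2.0])
    Xy = xgb.DMatrix(df, y, enable_categorical=True)
    booster = xgb.train({"device": "cpu"}, Xy, num_boost_round=4)

    # New data lists the categories in a different order; the predictor re-codes
    # them on the fly, so the result matches predicting on manually re-encoded codes.
    df_new = pd.DataFrame({"c": ["def", "abc"]}, dtype="category")
    predt = booster.inplace_predict(df_new)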
--- python-package/xgboost/testing/data.py | 4 +- python-package/xgboost/testing/ordinal.py | 181 +++++++++- src/common/error_msg.h | 7 + src/data/adapter.h | 12 +- src/data/cat_container.cc | 24 +- src/data/cat_container.cu | 50 ++- src/data/cat_container.h | 42 ++- src/data/device_adapter.cuh | 2 +- src/data/simple_dmatrix.cc | 5 + src/data/simple_dmatrix.cu | 2 +- src/data/sparse_page_dmatrix.cc | 4 +- src/encoder/ordinal.h | 7 + src/gbm/gbtree.cc | 12 + src/gbm/gbtree_model.cc | 14 +- src/gbm/gbtree_model.h | 23 +- src/learner.cc | 2 +- src/predictor/cpu_predictor.cc | 317 +++++++++++------- src/predictor/predict_fn.h | 42 +++ tests/cpp/data/test_cat_container.cu | 30 +- tests/python-gpu/test_gpu_updaters.py | 18 +- tests/python-gpu/test_gpu_with_sklearn.py | 11 +- tests/python/test_demos.py | 5 + tests/python/test_ordinal.py | 25 ++ tests/python/test_predict.py | 8 +- .../test_with_dask/test_with_dask.py | 2 +- 25 files changed, 673 insertions(+), 176 deletions(-) diff --git a/python-package/xgboost/testing/data.py b/python-package/xgboost/testing/data.py index 7124c48d9d0d..36367cdc26db 100644 --- a/python-package/xgboost/testing/data.py +++ b/python-package/xgboost/testing/data.py @@ -150,9 +150,11 @@ def pd_dtypes() -> Generator: # Categorical orig = orig.astype("category") + for c in orig.columns: + orig[c] = orig[c].cat.rename_categories(int) for Null in (np.nan, None, pd.NA): df = pd.DataFrame( - {"f0": [1.0, 2.0, Null, 3.0], "f1": [3.0, 2.0, Null, 1.0]}, + {"f0": [1, 2, Null, 3], "f1": [3, 2, Null, 1]}, dtype=pd.CategoricalDtype(), ) yield orig, df diff --git a/python-package/xgboost/testing/ordinal.py b/python-package/xgboost/testing/ordinal.py index 404d795951df..0d0ab6c21dfb 100644 --- a/python-package/xgboost/testing/ordinal.py +++ b/python-package/xgboost/testing/ordinal.py @@ -3,13 +3,16 @@ import os import tempfile -from typing import Any, Literal, Tuple, Type +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Literal, Tuple, Type, TypeVar import numpy as np +import pytest from ..compat import import_cupy from ..core import DMatrix, ExtMemQuantileDMatrix, QuantileDMatrix from ..data import _lazy_load_cudf_is_cat +from ..training import train from .data import IteratorForTest, is_pd_cat_dtype, make_categorical @@ -233,3 +236,179 @@ def run_cat_container_iter(device: Literal["cpu", "cuda"]) -> None: for _, v in cats.items(): assert v.null_count == 0 assert len(v) == n_cats + + +def run_cat_predict(device: Literal["cpu", "cuda"]) -> None: + """Basic tests for re-coding during prediction.""" + Df, _ = get_df_impl(device) + + def run_basic(DMatrixT: Type) -> None: + df = Df({"c": ["cdef", "abc", "def"]}, dtype="category") + y = np.array([0, 1, 2]) + + codes = df.c.cat.codes + encoded = np.array([codes.iloc[2], codes.iloc[1]]) # used with the next df + + Xy = DMatrixT(df, y, enable_categorical=True) + booster = train({"device": device}, Xy, num_boost_round=4) + + df = Df({"c": ["def", "abc"]}, dtype="category") + codes = df.c.cat.codes + + predt0 = booster.inplace_predict(df) + predt1 = booster.inplace_predict(encoded) + + assert_allclose(device, predt0, predt1) + + fmat = DMatrixT(df, enable_categorical=True) + predt2 = booster.predict(fmat) + assert_allclose(device, predt0, predt2) + + for dm in (DMatrix, QuantileDMatrix): + run_basic(dm) + + def run_mixed(DMatrixT: Type) -> None: + df = Df({"b": [2, 1, 3], "c": ["cdef", "abc", "def"]}, dtype="category") + y = np.array([0, 1, 2]) + + # used with the next df + b_codes = df.b.cat.codes + 
np.testing.assert_allclose(np.asarray(b_codes), np.array([1, 0, 2])) + # pick codes of 3, 1 + b_encoded = np.array([b_codes.iloc[2], b_codes.iloc[1]]) + + c_codes = df.c.cat.codes + np.testing.assert_allclose(np.asarray(c_codes), np.array([1, 0, 2])) + # pick codes of "def", "abc" + c_encoded = np.array([c_codes.iloc[2], c_codes.iloc[1]]) + encoded = np.stack([b_encoded, c_encoded], axis=1) + + Xy = DMatrixT(df, y, enable_categorical=True) + booster = train({"device": device}, Xy, num_boost_round=4) + + df = Df({"b": [3, 1], "c": ["def", "abc"]}, dtype="category") + predt0 = booster.inplace_predict(df) + predt1 = booster.inplace_predict(encoded) + assert_allclose(device, predt0, predt1) + + fmat = DMatrixT(df, enable_categorical=True) + predt2 = booster.predict(fmat) + assert_allclose(device, predt0, predt2) + + for dm in (DMatrix, QuantileDMatrix): + run_mixed(dm) + + +def run_cat_invalid(device: Literal["cpu", "cuda"]) -> None: + """Basic tests for invalid inputs.""" + Df, _ = get_df_impl(device) + + def run_invalid(DMatrixT: Type) -> None: + df = Df({"b": [2, 1, 3], "c": ["cdef", "abc", "def"]}, dtype="category") + y = np.array([0, 1, 2]) + + Xy = DMatrixT(df, y, enable_categorical=True) + booster = train({"device": device}, Xy, num_boost_round=4) + df["b"] = df["b"].astype(np.int64) + with pytest.raises(ValueError, match="The data type doesn't match"): + booster.inplace_predict(df) + + Xy = DMatrixT(df, y, enable_categorical=True) + with pytest.raises(ValueError, match="The data type doesn't match"): + booster.predict(Xy) + + for dm in (DMatrix, QuantileDMatrix): + run_invalid(dm) + + +def run_cat_thread_safety(device: Literal["cpu", "cuda"]) -> None: + """Basic tests for thread safety.""" + X, y = make_categorical(2048, 16, 112, onehot=False, cat_ratio=0.5) + Xy = QuantileDMatrix(X, y, enable_categorical=True) + booster = train({"device": device}, Xy, num_boost_round=10) + + def run_thread_safety(DMatrixT: Type) -> bool: + Xy = DMatrixT(X, enable_categorical=True) + predt0 = booster.predict(Xy) + predt1 = booster.inplace_predict(X) + assert_allclose(device, predt0, predt1) + return True + + futures = [] + for dm in (DMatrix, QuantileDMatrix): + with ThreadPoolExecutor(max_workers=10) as e: + for _ in range(10): + fut = e.submit(run_thread_safety, dm) + futures.append(fut) + + for f in futures: + assert f.result() + + +U = TypeVar("U", DMatrix, QuantileDMatrix) + + +def _make_dm(DMatrixT: Type[U], ref: DMatrix, *args: Any, **kwargs: Any) -> U: + if DMatrixT is QuantileDMatrix: + return DMatrixT(*args, ref=ref, enable_categorical=True, **kwargs) + return DMatrixT(*args, enable_categorical=True, **kwargs) + + +def _run_predt( + device: str, + DMatrixT: Type, + pred_contribs: bool, + pred_interactions: bool, + pred_leaf: bool, +) -> None: + Df, _ = get_df_impl(device) + + df = Df({"c": ["cdef", "abc", "def"]}, dtype="category") + y = np.array([0, 1, 2]) + + codes = df.c.cat.codes + encoded = np.array([codes.iloc[2], codes.iloc[1]]) # used with the next df + + Xy = DMatrixT(df, y, enable_categorical=True) + booster = train({"device": device}, Xy, num_boost_round=4) + + df = Df({"c": ["def", "abc"]}, dtype="category") + codes = df.c.cat.codes + + # Contribution + predt0 = booster.predict( + _make_dm(DMatrixT, ref=Xy, data=df), + pred_contribs=pred_contribs, + pred_interactions=pred_interactions, + pred_leaf=pred_leaf, + ) + df = Df({"c": encoded}) + predt1 = booster.predict( + _make_dm(DMatrixT, ref=Xy, data=encoded.reshape(2, 1), feature_names=["c"]), + pred_contribs=pred_contribs, + 
pred_interactions=pred_interactions, + pred_leaf=pred_leaf, + ) + assert_allclose(device, predt0, predt1) + + +def run_cat_shap(device: Literal["cpu", "cuda"]) -> None: + """Basic tests for SHAP values.""" + + for dm in (DMatrix, QuantileDMatrix): + _run_predt( + device, dm, pred_contribs=True, pred_interactions=False, pred_leaf=False + ) + + for dm in (DMatrix, QuantileDMatrix): + _run_predt( + device, dm, pred_contribs=False, pred_interactions=True, pred_leaf=False + ) + + +def run_cat_leaf(device: Literal["cpu", "cuda"]) -> None: + """Basic tests for leaf prediction.""" + # QuantileDMatrix is not supported by leaf. + _run_predt( + device, DMatrix, pred_contribs=False, pred_interactions=False, pred_leaf=True + ) diff --git a/src/common/error_msg.h b/src/common/error_msg.h index 16652d1958ba..78168e1b1f13 100644 --- a/src/common/error_msg.h +++ b/src/common/error_msg.h @@ -128,5 +128,12 @@ constexpr StringView ZeroCudaMemory() { "support. If you are using other types of memory pool, please consider reserving a " "portion of the GPU memory for XGBoost."; } + +// float64 is not supported by JSON yet. Also, floating point as categories is tricky +// since floating point equality test is inaccurate for most hardware. +constexpr StringView NoFloatCat() { + return "Category index from DataFrame has floating point dtype, consider using strings or " + "integers instead."; +} } // namespace xgboost::error #endif // XGBOOST_COMMON_ERROR_MSG_H_ diff --git a/src/data/adapter.h b/src/data/adapter.h index a9e97b3feb1b..339fbcd90e5d 100644 --- a/src/data/adapter.h +++ b/src/data/adapter.h @@ -16,10 +16,11 @@ #include // for variant #include // for vector -#include "../common/math.h" -#include "../encoder/ordinal.h" // for CatStrArrayView -#include "../encoder/types.h" // for TupToVarT -#include "array_interface.h" // for CategoricalIndexArgTypes +#include "../common/error_msg.h" // for NoFloatCat +#include "../common/math.h" // for CheckNAN +#include "../encoder/ordinal.h" // for CatStrArrayView +#include "../encoder/types.h" // for TupToVarT +#include "array_interface.h" // for CategoricalIndexArgTypes #include "xgboost/base.h" #include "xgboost/data.h" #include "xgboost/logging.h" @@ -627,6 +628,9 @@ template using T = typename decltype(t)::value_type; constexpr bool kKnownType = enc::MemberOf, enc::CatPrimIndexTypes>::value; CHECK(kKnownType) << "Unsupported categorical index type."; + if constexpr (std::is_floating_point_v) { + LOG(FATAL) << error::NoFloatCat(); + } auto span = common::Span{t.Values().data(), t.Size()}; if constexpr (kKnownType) { p_cat_columns->emplace_back(span); diff --git a/src/data/cat_container.cc b/src/data/cat_container.cc index c70b7fc10579..d53eedf70fe9 100644 --- a/src/data/cat_container.cc +++ b/src/data/cat_container.cc @@ -9,8 +9,9 @@ #include // for move #include // for vector -#include "../encoder/types.h" // for Overloaded -#include "xgboost/json.h" // for Json +#include "../common/error_msg.h" // for NoFloatCat +#include "../encoder/types.h" // for Overloaded +#include "xgboost/json.h" // for Json namespace xgboost { CatContainer::CatContainer(enc::HostColumnsView const& df) : CatContainer{} { @@ -39,6 +40,12 @@ CatContainer::CatContainer(enc::HostColumnsView const& df) : CatContainer{} { using T = typename cpu_impl::ViewToStorageImpl>::Type; this->cpu_impl_->columns.emplace_back(); + using ElemT = typename T::value_type; + + if constexpr (std::is_floating_point_v) { + LOG(FATAL) << error::NoFloatCat(); + } + this->cpu_impl_->columns.back().emplace(); auto& v = 
std::get(this->cpu_impl_->columns.back()); v.resize(values.size()); @@ -54,6 +61,9 @@ CatContainer::CatContainer(enc::HostColumnsView const& df) : CatContainer{} { CHECK(this->HostCanRead()); CHECK_EQ(this->n_total_cats_, df.feature_segments.back()); CHECK_GE(this->n_total_cats_, 0) << "Too many categories."; + if (this->n_total_cats_ > 0) { + CHECK(!this->cpu_impl_->columns.empty()); + } } namespace { @@ -229,17 +239,21 @@ CatContainer::CatContainer() : cpu_impl_{std::make_uniqueCopyCommon(that); } +void CatContainer::Copy(Context const* ctx, CatContainer const& that) { + [[maybe_unused]] auto h_view = that.HostView(); + this->CopyCommon(ctx, that); + this->cpu_impl_->Copy(that.cpu_impl_.get()); +} [[nodiscard]] enc::HostColumnsView CatContainer::HostView() const { return this->HostViewImpl(); } +[[nodiscard]] bool CatContainer::Empty() const { return this->cpu_impl_->columns.empty(); } + void CatContainer::Sort(Context const* ctx) { CHECK(ctx->IsCPU()); auto view = this->HostView(); this->sorted_idx_.HostVector().resize(view.n_total_cats); enc::SortNames(enc::Policy{}, view, this->sorted_idx_.HostSpan()); } - -[[nodiscard]] bool CatContainer::DeviceCanRead() const { return false; } #endif // !defined(XGBOOST_USE_CUDA) } // namespace xgboost diff --git a/src/data/cat_container.cu b/src/data/cat_container.cu index fa5134905f77..64f206528cc6 100644 --- a/src/data/cat_container.cu +++ b/src/data/cat_container.cu @@ -141,20 +141,23 @@ CatContainer::CatContainer(DeviceOrd device, enc::DeviceColumnsView const& df) : if (this->n_total_cats_ > 0) { CHECK(this->DeviceCanRead()); CHECK(!this->HostCanRead()); + CHECK(!this->cu_impl_->columns.empty()); } } CatContainer::~CatContainer() = default; -[[nodiscard]] bool CatContainer::DeviceCanRead() const { return !this->cu_impl_->columns.empty(); } - void CatContainer::Copy(Context const* ctx, CatContainer const& that) { - this->CopyCommon(that); if (ctx->IsCPU()) { - auto h_view = that.HostView(); - CHECK(!h_view.Empty()); + // Pull data to host + [[maybe_unused]] auto h_view = that.HostView(); + this->CopyCommon(ctx, that); this->cpu_impl_->Copy(that.cpu_impl_.get()); + CHECK(!this->DeviceCanRead()); } else { + // Pull data to device + [[maybe_unused]] auto d_view = that.DeviceView(ctx); + this->CopyCommon(ctx, that); auto const& that_impl = that.cu_impl_; this->cu_impl_->columns.resize(that.cu_impl_->columns.size()); @@ -186,17 +189,38 @@ void CatContainer::Copy(Context const* ctx, CatContainer const& that) { col); } this->cu_impl_->columns_v = h_columns_v; + CHECK(this->Empty() || !this->HostCanRead()); + } + if (ctx->IsCPU()) { + CHECK_EQ(this->cpu_impl_->columns_v.size(), that.cpu_impl_->columns_v.size()); + CHECK_EQ(this->cpu_impl_->columns.size(), that.cpu_impl_->columns.size()); + CHECK(this->HostCanRead()); + } else { + CHECK_EQ(this->cu_impl_->columns_v.size(), that.cu_impl_->columns_v.size()); + CHECK_EQ(this->cu_impl_->columns.size(), that.cu_impl_->columns.size()); + CHECK(this->DeviceCanRead()); } + CHECK_EQ(this->Empty(), that.Empty()); + CHECK_EQ(this->NumCatsTotal(), that.NumCatsTotal()); +} + +[[nodiscard]] bool CatContainer::Empty() const { + return this->HostCanRead() ? 
this->cpu_impl_->columns.empty() : this->cu_impl_->columns.empty(); } void CatContainer::Sort(Context const* ctx) { + if (!this->HasCategorical()) { + return; + } + if (ctx->IsCPU()) { auto view = this->HostView(); + CHECK(!view.Empty()) << view.n_total_cats; this->sorted_idx_.HostVector().resize(view.n_total_cats); - enc::SortNames(enc::Policy{}, view, this->sorted_idx_.HostSpan()); + enc::SortNames(cpu_impl::EncPolicy, view, this->sorted_idx_.HostSpan()); } else { auto view = this->DeviceView(ctx); - CHECK(!view.Empty()) << this->HostView().Size(); + CHECK(!view.Empty()) << view.n_total_cats; this->sorted_idx_.SetDevice(ctx->Device()); this->sorted_idx_.Resize(view.n_total_cats); enc::SortNames(cuda_impl::EncPolicy, view, this->sorted_idx_.DeviceSpan()); @@ -206,21 +230,29 @@ void CatContainer::Sort(Context const* ctx) { [[nodiscard]] enc::HostColumnsView CatContainer::HostView() const { std::lock_guard guard{device_mu_}; if (!this->HostCanRead()) { + this->feature_segments_.ConstHostSpan(); // Lazy copy to host this->cu_impl_->CopyTo(this->cpu_impl_.get()); } + CHECK(this->HostCanRead()); return this->HostViewImpl(); } [[nodiscard]] enc::DeviceColumnsView CatContainer::DeviceView(Context const* ctx) const { CHECK(ctx->IsCUDA()); std::lock_guard guard{device_mu_}; - this->feature_segments_.SetDevice(ctx->Device()); if (!this->DeviceCanRead()) { + this->feature_segments_.SetDevice(ctx->Device()); + this->feature_segments_.ConstDeviceSpan(); // Lazy copy to device auto h_view = this->HostViewImpl(); - CHECK(!h_view.Empty()); this->cu_impl_->CopyFrom(h_view); + CHECK_EQ(this->cu_impl_->columns_v.size(), this->cpu_impl_->columns_v.size()); + CHECK_EQ(this->cu_impl_->columns.size(), this->cpu_impl_->columns.size()); + } + CHECK(this->DeviceCanRead()); + if (this->n_total_cats_ != 0) { + CHECK(!this->cu_impl_->columns_v.empty()); } return {dh::ToSpan(this->cu_impl_->columns_v), this->feature_segments_.ConstDeviceSpan(), this->n_total_cats_}; diff --git a/src/data/cat_container.h b/src/data/cat_container.h index b6ceed1f4219..1990e51a81f8 100644 --- a/src/data/cat_container.h +++ b/src/data/cat_container.h @@ -104,22 +104,37 @@ struct CatContainerImpl; */ class CatContainer { /** - * @brief Implementation of the Copy method, used by both CPU and GPU. + * @brief Implementation of the Copy method, used by both CPU and GPU. Note that this + * method changes the permission in the HostDeviceVector as we need to pull data into + * targeted devices. 
*/ - void CopyCommon(CatContainer const& that) { - this->sorted_idx_.SetDevice(that.sorted_idx_.Device()); + void CopyCommon(Context const* ctx, CatContainer const& that) { + auto device = ctx->Device(); + + that.sorted_idx_.SetDevice(device); + this->sorted_idx_.SetDevice(device); this->sorted_idx_.Resize(that.sorted_idx_.Size()); this->sorted_idx_.Copy(that.sorted_idx_); - this->feature_segments_.SetDevice(that.feature_segments_.Device()); + this->feature_segments_.SetDevice(device); + that.feature_segments_.SetDevice(device); this->feature_segments_.Resize(that.feature_segments_.Size()); this->feature_segments_.Copy(that.feature_segments_); this->n_total_cats_ = that.n_total_cats_; + + if (!device.IsCPU()) { + // Pull to device + this->sorted_idx_.ConstDevicePointer(); + this->feature_segments_.ConstDevicePointer(); + } } [[nodiscard]] enc::HostColumnsView HostViewImpl() const { CHECK_EQ(this->cpu_impl_->columns.size(), this->cpu_impl_->columns_v.size()); + if (this->n_total_cats_ != 0) { + CHECK(!this->cpu_impl_->columns_v.empty()); + } return {common::Span{this->cpu_impl_->columns_v}, this->feature_segments_.ConstHostSpan(), this->n_total_cats_}; } @@ -134,17 +149,21 @@ class CatContainer { void Copy(Context const* ctx, CatContainer const& that); - [[nodiscard]] bool HostCanRead() const { - return !this->cpu_impl_->columns.empty() || this->n_total_cats_ == 0; - } - [[nodiscard]] bool DeviceCanRead() const; + [[nodiscard]] bool HostCanRead() const { return this->feature_segments_.HostCanRead(); } + [[nodiscard]] bool DeviceCanRead() const { return this->feature_segments_.DeviceCanRead(); } // Mostly used for testing. void Push(cpu_impl::ColumnType const& column) { this->cpu_impl_->columns.emplace_back(column); } - - [[nodiscard]] bool Empty() const { return this->cpu_impl_->columns.empty(); } + /** + * @brief Wether the container is initialized at all. If the input is not a DataFrame, + * this method returns True. + */ + [[nodiscard]] bool Empty() const; [[nodiscard]] std::size_t NumFeatures() const { return this->cpu_impl_->columns.size(); } + /** + * @brief The number of categories across all features. 
+ */ [[nodiscard]] std::size_t NumCatsTotal() const { return this->n_total_cats_; } /** @@ -160,10 +179,9 @@ class CatContainer { [[nodiscard]] common::Span RefSortedIndex(Context const* ctx) const { std::lock_guard guard{device_mu_}; if (ctx->IsCPU()) { - CHECK(this->sorted_idx_.HostCanRead()); return this->sorted_idx_.ConstHostSpan(); } else { - CHECK(this->sorted_idx_.DeviceCanRead()); + sorted_idx_.SetDevice(ctx->Device()); return this->sorted_idx_.ConstDeviceSpan(); } } diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index 6203435b8c95..672767db92c4 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -109,7 +109,7 @@ class CudfAdapter : public detail::SingleBatchDataIter { explicit CudfAdapter(std::string cuda_interfaces_str) : CudfAdapter{StringView{cuda_interfaces_str}} {} - const CudfAdapterBatch& Value() const override { + [[nodiscard]] CudfAdapterBatch const& Value() const override { CHECK_EQ(batch_.columns_.data(), columns_.data().get()); return batch_; } diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 0cdaccad4109..c25bdf8befc7 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -50,10 +50,15 @@ DMatrix* SimpleDMatrix::Slice(common::Span ridxs) { out->Info() = this->Info().Slice(&ctx, h_ridx, h_offset.back()); } out->fmat_ctx_ = this->fmat_ctx_; + + out->Info().Cats()->Copy(&fmat_ctx_, *this->Info().Cats()); return out; } DMatrix* SimpleDMatrix::SliceCol(int num_slices, int slice_id) { + if (this->Cats()->HasCategorical()) { + LOG(FATAL) << "Slicing column is not supported for DataFrames with categorical columns."; + } auto out = new SimpleDMatrix; SparsePage& out_page = *out->sparse_page_; auto const slice_size = info_.num_col_ / num_slices; diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu index 1436d982bc29..f502f5ee56c8 100644 --- a/src/data/simple_dmatrix.cu +++ b/src/data/simple_dmatrix.cu @@ -7,7 +7,7 @@ #include "../common/cuda_rt_utils.h" // for CurrentDevice #include "cat_container.h" // for CatContainer -#include "device_adapter.cuh" // for CurrentDevice +#include "device_adapter.cuh" #include "simple_dmatrix.cuh" #include "simple_dmatrix.h" #include "xgboost/context.h" // for Context diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index be726c80b48b..160602549324 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -127,7 +127,7 @@ BatchSet SparsePageDMatrix::GetRowBatches() { } BatchSet SparsePageDMatrix::GetColumnBatches(Context const *ctx) { - auto id = MakeCache(this, ".col.page", on_host_, cache_prefix_, &cache_info_); + auto id = MakeCache(this, ".col.page", false, cache_prefix_, &cache_info_); CHECK_NE(this->Info().num_col_, 0); this->InitializeSparsePage(ctx); if (!column_source_) { @@ -141,7 +141,7 @@ BatchSet SparsePageDMatrix::GetColumnBatches(Context const *ctx) { } BatchSet SparsePageDMatrix::GetSortedColumnBatches(Context const *ctx) { - auto id = MakeCache(this, ".sorted.col.page", on_host_, cache_prefix_, &cache_info_); + auto id = MakeCache(this, ".sorted.col.page", false, cache_prefix_, &cache_info_); CHECK_NE(this->Info().num_col_, 0); this->InitializeSparsePage(ctx); if (!sorted_column_source_) { diff --git a/src/encoder/ordinal.h b/src/encoder/ordinal.h index 83269d3c913f..bfb334d29666 100644 --- a/src/encoder/ordinal.h +++ b/src/encoder/ordinal.h @@ -342,6 +342,13 @@ void Recode(ExecPolicy const &policy, HostColumnsView orig_enc, Span* in_gpair, bst_target_t const 
n_groups = model_.learner_model_param->OutputLength(); monitor_.Start("BoostNewTrees"); + // Define the categories. + if (this->model_.Cats()->Empty() && !p_fmat->Cats()->Empty()) { + auto in_cats = p_fmat->Cats(); + this->model_.Cats()->Copy(this->ctx_, *in_cats); + this->model_.Cats()->Sort(this->ctx_); + } else { + CHECK_EQ(this->model_.Cats()->NumCatsTotal(), p_fmat->Cats()->NumCatsTotal()) + << "A new dataset with different categorical features is used for training an existing " + "model."; + } + predt->predictions.SetDevice(ctx_->Device()); auto out = linalg::MakeTensorView(ctx_, &predt->predictions, p_fmat->Info().num_row_, model_.learner_model_param->OutputLength()); diff --git a/src/gbm/gbtree_model.cc b/src/gbm/gbtree_model.cc index 2edb456c95de..c94c6525fea2 100644 --- a/src/gbm/gbtree_model.cc +++ b/src/gbm/gbtree_model.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019-2023, XGBoost Contributors + * Copyright 2019-2024, XGBoost Contributors */ #include "gbtree_model.h" @@ -132,6 +132,8 @@ void GBTreeModel::SaveModel(Json* p_out) const { std::transform(iteration_indptr.cbegin(), iteration_indptr.cend(), jiteration_indptr.begin(), [](bst_tree_t i) { return Integer{i}; }); out["iteration_indptr"] = Array{std::move(jiteration_indptr)}; + + this->Cats()->Save(&out["cats"]); } void GBTreeModel::LoadModel(Json const& in) { @@ -142,11 +144,11 @@ void GBTreeModel::LoadModel(Json const& in) { auto const& jmodel = get(in); - auto const& trees_json = get(in["trees"]); + auto const& trees_json = get(jmodel.at("trees")); CHECK_EQ(trees_json.size(), param.num_trees); trees.resize(param.num_trees); - auto const& tree_info_json = get(in["tree_info"]); + auto const& tree_info_json = get(jmodel.at("tree_info")); CHECK_EQ(tree_info_json.size(), param.num_trees); tree_info.resize(param.num_trees); @@ -171,6 +173,12 @@ void GBTreeModel::LoadModel(Json const& in) { MakeIndptr(this); } + auto p_cats = std::make_shared(); + auto cat_it = jmodel.find("cats"); + if (cat_it != jmodel.cend()) { + p_cats->Load(cat_it->second); + } + this->cats_ = std::move(p_cats); Validate(*this); } diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h index 32fa868638bb..7d7893fb3391 100644 --- a/src/gbm/gbtree_model.h +++ b/src/gbm/gbtree_model.h @@ -1,6 +1,7 @@ /** - * Copyright 2017-2023, XGBoost Contributors - * \file gbtree_model.h + * Copyright 2017-2025, XGBoost Contributors + * + * @file gbtree_model.h */ #ifndef XGBOOST_GBM_GBTREE_MODEL_H_ #define XGBOOST_GBM_GBTREE_MODEL_H_ @@ -19,6 +20,7 @@ #include #include "../common/threading_utils.h" +#include "../data/cat_container.h" // for CatContainer namespace xgboost { @@ -94,7 +96,7 @@ struct GBTreeModel : public Model { void InitTreesToUpdate() { if (trees_to_update.size() == 0u) { - for (auto & tree : trees) { + for (auto& tree : trees) { trees_to_update.push_back(std::move(tree)); } trees.clear(); @@ -146,22 +148,27 @@ struct GBTreeModel : public Model { // model parameter GBTreeModelParam param; /*! \brief vector of trees stored in the model */ - std::vector > trees; + std::vector> trees; /*! \brief for the update process, a place to keep the initial trees */ - std::vector > trees_to_update; + std::vector> trees_to_update; /** - * \brief Group index for trees. + * @brief Group index for trees. */ std::vector tree_info; /** - * \brief Number of trees accumulated for each iteration. + * @brief Number of trees accumulated for each iteration. 
*/ std::vector iteration_indptr{0}; + [[nodiscard]] CatContainer const* Cats() const { return this->cats_.get(); } + [[nodiscard]] CatContainer* Cats() { return this->cats_.get(); } + void Cats(std::shared_ptr cats) { this->cats_ = cats; } + private: /** - * \brief Whether the stack contains multi-target tree. + * @brief Categories in the training data. */ + std::shared_ptr cats_{std::make_shared()}; Context const* ctx_; }; } // namespace gbm diff --git a/src/learner.cc b/src/learner.cc index 34f395beb34b..d45b533396db 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -843,7 +843,7 @@ class LearnerConfiguration : public Learner { } } - void InitEstimation(MetaInfo const& info, linalg::Tensor* base_score) { + void InitEstimation(MetaInfo const& info, linalg::Vector* base_score) { base_score->Reshape(1); collective::ApplyWithLabels(this->Ctx(), info, base_score->Data(), [&] { UsePtr(obj_)->InitEstimation(info, base_score); }); diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 9e6289c2b630..d986882a6795 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -20,6 +20,7 @@ #include "../common/math.h" // for CheckNAN #include "../common/threading_utils.h" // for ParallelFor #include "../data/adapter.h" // for ArrayAdapter, CSRAdapter, CSRArrayAdapter +#include "../data/cat_container.h" // for CatContainer #include "../data/gradient_index.h" // for GHistIndexMatrix #include "../data/proxy_dmatrix.h" // for DMatrixProxy #include "../gbm/gbtree_model.h" // for GBTreeModel, GBTreeModelParam @@ -96,11 +97,11 @@ void PredValueByOneTree(RegTree::FVec const &p_feats, MultiTargetTree const &tre } // namespace multi namespace { -void PredictByAllTrees(gbm::GBTreeModel const &model, std::uint32_t const tree_begin, - std::uint32_t const tree_end, std::size_t const predict_offset, +void PredictByAllTrees(gbm::GBTreeModel const &model, bst_tree_t const tree_begin, + bst_tree_t const tree_end, std::size_t const predict_offset, std::vector const &thread_temp, std::size_t const offset, std::size_t const block_size, linalg::MatrixView out_predt) { - for (std::uint32_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) { + for (bst_tree_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) { auto const &tree = *model.trees.at(tree_id); auto const &cats = tree.GetCategoriesMatrix(); bool has_categorical = tree.HasCategoricalSplit(); @@ -169,28 +170,35 @@ struct DataToFeatVec { } }; -struct SparsePageView : public DataToFeatVec { - bst_idx_t base_rowid; - HostSparsePageView view; +template +class SparsePageView : public DataToFeatVec> { + EncAccessor acc_; + HostSparsePageView const view_; - explicit SparsePageView(SparsePage const *p) : base_rowid{p->base_rowid} { view = p->GetView(); } - [[nodiscard]] std::size_t Size() const { return view.Size(); } + public: + bst_idx_t const base_rowid; + + SparsePageView(SparsePage const *p, EncAccessor &&acc) + : acc_{std::forward(acc)}, view_{p->GetView()}, base_rowid{p->base_rowid} {} + [[nodiscard]] std::size_t Size() const { return view_.Size(); } [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float *out) const { - auto p_data = view[ridx].data(); + auto p_data = view_[ridx].data(); - for (std::size_t i = 0, n = view[ridx].size(); i < n; ++i) { + for (std::size_t i = 0, n = view_[ridx].size(); i < n; ++i) { auto const &entry = p_data[i]; - out[entry.index] = entry.fvalue; + out[entry.index] = acc_(entry); } - return view[ridx].size(); + return view_[ridx].size(); } }; -struct GHistIndexMatrixView : public 
DataToFeatVec { +template +class GHistIndexMatrixView : public DataToFeatVec> { private: GHistIndexMatrix const &page_; + EncAccessor acc_; common::Span ft_; std::vector const &ptrs_; @@ -202,8 +210,10 @@ struct GHistIndexMatrixView : public DataToFeatVec { bst_idx_t const base_rowid; public: - GHistIndexMatrixView(GHistIndexMatrix const &_page, common::Span ft) + GHistIndexMatrixView(GHistIndexMatrix const &_page, EncAccessor &&acc, + common::Span ft) : page_{_page}, + acc_{acc}, ft_{ft}, ptrs_{_page.cut.Ptrs()}, mins_{_page.cut.MinValues()}, @@ -232,30 +242,30 @@ struct GHistIndexMatrixView : public DataToFeatVec { fvalue = common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, bin_idx); } - out[fidx] = fvalue; + out[fidx] = acc_(fvalue, fidx); } }); n_non_missings += n_features; } else { for (bst_feature_t fidx = 0; fidx < n_features; ++fidx) { - float f = std::numeric_limits::quiet_NaN(); + float fvalue = std::numeric_limits::quiet_NaN(); bool is_cat = common::IsCat(ft_, fidx); if (columns_.GetColumnType(fidx) == common::kSparseColumn) { // Special handling for extremely sparse data. Just binary search. auto bin_idx = page_.GetGindex(gridx, fidx); if (bin_idx != -1) { if (is_cat) { - f = values_[bin_idx]; + fvalue = values_[bin_idx]; } else { - f = common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, - bin_idx); + fvalue = common::HistogramCuts::NumericBinValue(this->ptrs_, values_, mins_, fidx, + bin_idx); } } } else { - f = page_.GetFvalue(ptrs_, values_, mins_, gridx, fidx, is_cat); + fvalue = page_.GetFvalue(ptrs_, values_, mins_, gridx, fidx, is_cat); } - if (!common::CheckNAN(f)) { - out[fidx] = f; + if (!common::CheckNAN(fvalue)) { + out[fidx] = acc_(fvalue, fidx); n_non_missings++; } } @@ -263,17 +273,18 @@ struct GHistIndexMatrixView : public DataToFeatVec { return n_non_missings; } - [[nodiscard]] auto Size() const { return page_.Size(); } + [[nodiscard]] bst_idx_t Size() const { return page_.Size(); } }; -template -class AdapterView : public DataToFeatVec> { +template +class AdapterView : public DataToFeatVec> { Adapter const *adapter_; float missing_; + EncAccessor const &acc_; public: - explicit AdapterView(Adapter const *adapter, float missing) - : adapter_{adapter}, missing_{missing} {} + explicit AdapterView(Adapter const *adapter, float missing, EncAccessor const &acc) + : adapter_{adapter}, missing_{missing}, acc_{acc} {} [[nodiscard]] bst_idx_t DoFill(bst_idx_t ridx, float *out) const { auto const &batch = adapter_->Value(); @@ -282,20 +293,21 @@ class AdapterView : public DataToFeatVec> { for (size_t c = 0; c < row.Size(); ++c) { auto e = row.GetElement(c); if (missing_ != e.value && !common::CheckNAN(e.value)) { - out[e.column_idx] = e.value; + auto fvalue = this->acc_(e); + out[e.column_idx] = fvalue; n_non_missings++; } } return n_non_missings; } - [[nodiscard]] size_t Size() const { return adapter_->NumRows(); } + [[nodiscard]] bst_idx_t Size() const { return adapter_->NumRows(); } bst_idx_t const static base_rowid = 0; // NOLINT }; -template -void PredictBatchByBlockOfRowsKernel(DataView batch, gbm::GBTreeModel const &model, +template +void PredictBatchByBlockOfRowsKernel(DataView const &batch, gbm::GBTreeModel const &model, bst_tree_t tree_begin, bst_tree_t tree_end, std::vector *p_thread_temp, std::int32_t n_threads, @@ -354,6 +366,27 @@ static void InitThreadTemp(int nthread, std::vector *out) { out->resize(nthread, RegTree::FVec()); } } + +auto MakeCatAccessor(Context const *ctx, enc::HostColumnsView const &cats, + 
gbm::GBTreeModel const &model) { + std::vector mapping(cats.n_total_cats); + auto sorted_idx = model.Cats()->RefSortedIndex(ctx); + auto orig_enc = model.Cats()->HostView(); + enc::Recode(cpu_impl::EncPolicy, orig_enc, sorted_idx, cats, common::Span{mapping}); + auto cats_mapping = enc::MappingView{cats.feature_segments, mapping}; + auto acc = CatAccessor{cats_mapping}; + return std::tuple{acc, std::move(mapping)}; +} + +bool ShouldUseBlock(DMatrix *p_fmat) { + // Threshold to use block-based prediction. + constexpr double kDensityThresh = .5; + bst_idx_t n_samples = p_fmat->Info().num_row_; + bst_idx_t total = std::max(n_samples * p_fmat->Info().num_col_, static_cast(1)); + double density = static_cast(p_fmat->Info().num_nonzero_) / static_cast(total); + bool blocked = density > kDensityThresh; + return blocked; +} } // anonymous namespace /** @@ -412,22 +445,25 @@ class ColumnSplitHelper { void PredictDMatrix(Context const *ctx, DMatrix *p_fmat, std::vector *out_preds) { CHECK(xgboost::collective::IsDistributed()) << "column-split prediction is only supported for distributed training"; + if (this->model_.Cats()->HasCategorical()) { + LOG(FATAL) << "Categorical feature is not yet supported with column-split."; + } for (auto const &batch : p_fmat->GetBatches()) { CHECK_EQ(out_preds->size(), p_fmat->Info().num_row_ * model_.learner_model_param->num_output_group); - PredictBatchKernel(ctx, SparsePageView{&batch}, out_preds); + PredictBatchKernel(ctx, SparsePageView{&batch, NoOpAccessor{}}, out_preds); } } - void PredictLeaf(Context const* ctx, DMatrix *p_fmat, std::vector *out_preds) { + void PredictLeaf(Context const *ctx, DMatrix *p_fmat, std::vector *out_preds) { CHECK(xgboost::collective::IsDistributed()) << "column-split prediction is only supported for distributed training"; for (auto const &batch : p_fmat->GetBatches()) { CHECK_EQ(out_preds->size(), p_fmat->Info().num_row_ * (tree_end_ - tree_begin_)); - PredictBatchKernel(ctx, SparsePageView{&batch}, - out_preds); + PredictBatchKernel(ctx, SparsePageView{&batch, NoOpAccessor{}}, + out_preds); } } @@ -548,8 +584,8 @@ class ColumnSplitHelper { } } - template - void PredictBatchKernel(Context const* ctx, DataView batch, std::vector *out_preds) { + template + void PredictBatchKernel(Context const *ctx, DataView batch, std::vector *out_preds) { auto const num_group = model_.learner_model_param->num_output_group; // parallel over local batch @@ -646,6 +682,7 @@ class CPUPredictor : public Predictor { if (p_fmat->Info().IsColumnSplit()) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict DMatrix with column split" << MTNotImplemented(); + CHECK(!model.Cats()->HasCategorical()) << "The re-coder doesn't support column split yet."; ColumnSplitHelper helper(this->ctx_->Threads(), model, tree_begin, tree_end); helper.PredictDMatrix(ctx_, p_fmat, out_preds); @@ -653,46 +690,54 @@ class CPUPredictor : public Predictor { } auto const n_threads = this->ctx_->Threads(); - constexpr double kDensityThresh = .5; - size_t total = - std::max(p_fmat->Info().num_row_ * p_fmat->Info().num_col_, static_cast(1)); - double density = static_cast(p_fmat->Info().num_nonzero_) / static_cast(total); - bool blocked = density > kDensityThresh; + + bool blocked = ShouldUseBlock(p_fmat); std::vector feat_vecs; InitThreadTemp(n_threads * (blocked ? 
kBlockOfRowsSize : 1), &feat_vecs); - std::size_t n_samples = p_fmat->Info().num_row_; - std::size_t n_groups = model.learner_model_param->OutputLength(); + // Create a writable view on the output prediction vector. + bst_idx_t n_groups = model.learner_model_param->OutputLength(); + bst_idx_t n_samples = p_fmat->Info().num_row_; CHECK_EQ(out_preds->size(), n_samples * n_groups); auto out_predt = linalg::MakeTensorView(ctx_, *out_preds, n_samples, n_groups); - if (!p_fmat->PageExists()) { - auto ft = p_fmat->Info().feature_types.ConstHostVector(); - for (auto const &batch : p_fmat->GetBatches(ctx_, {})) { - if (blocked) { - PredictBatchByBlockOfRowsKernel( - GHistIndexMatrixView{batch, ft}, model, tree_begin, tree_end, &feat_vecs, n_threads, - out_predt); - } else { - PredictBatchByBlockOfRowsKernel( - GHistIndexMatrixView{batch, ft}, model, tree_begin, tree_end, &feat_vecs, n_threads, - out_predt); + // Dispatching function for various configuration. + auto launch = [&](auto &&acc) { + using Enc = std::remove_reference_t; // The encoder. + if (!p_fmat->PageExists()) { + // Run prediction on QDM. + auto ft = p_fmat->Info().feature_types.ConstHostVector(); + for (auto const &page : p_fmat->GetBatches(ctx_, {})) { + auto batch = GHistIndexMatrixView{page, std::forward(acc), ft}; + if (blocked) { + PredictBatchByBlockOfRowsKernel(batch, model, tree_begin, tree_end, + &feat_vecs, n_threads, out_predt); + } else { + PredictBatchByBlockOfRowsKernel<1>(batch, model, tree_begin, tree_end, &feat_vecs, + n_threads, out_predt); + } } - } - } else { - for (auto const &batch : p_fmat->GetBatches()) { - if (blocked) { - PredictBatchByBlockOfRowsKernel( - SparsePageView{&batch}, model, tree_begin, tree_end, &feat_vecs, n_threads, - out_predt); - - } else { - PredictBatchByBlockOfRowsKernel(SparsePageView{&batch}, model, - tree_begin, tree_end, &feat_vecs, - n_threads, out_predt); + } else { + // Run prediction on SparsePage + for (auto const &page : p_fmat->GetBatches()) { + auto batch = SparsePageView{&page, std::forward(acc)}; + if (blocked) { + PredictBatchByBlockOfRowsKernel(batch, model, tree_begin, tree_end, + &feat_vecs, n_threads, out_predt); + } else { + PredictBatchByBlockOfRowsKernel<1>(batch, model, tree_begin, tree_end, &feat_vecs, + n_threads, out_predt); + } } } + }; + + if (model.Cats()->HasCategorical() && !p_fmat->Cats()->Empty()) { + auto [acc, mapping] = MakeCatAccessor(ctx_, p_fmat->Cats()->HostView(), model); + launch(acc); + } else { + launch(NoOpAccessor{}); } } @@ -769,9 +814,9 @@ class CPUPredictor : public Predictor { this->PredictDMatrix(dmat, &out_preds->HostVector(), model, tree_begin, tree_end); } - template + template void DispatchedInplacePredict(std::any const &x, std::shared_ptr p_m, - const gbm::GBTreeModel &model, float missing, + gbm::GBTreeModel const &model, float missing, PredictionCacheEntry *out_preds, bst_tree_t tree_begin, bst_tree_t tree_end) const { auto const n_threads = this->ctx_->Threads(); @@ -783,39 +828,59 @@ class CPUPredictor : public Predictor { CHECK_EQ(p_m->Info().num_col_, m->NumColumns()); this->InitOutPredictions(p_m->Info(), &(out_preds->predictions), model); + bool blocked = ShouldUseBlock(p_m.get()); + auto &predictions = out_preds->predictions.HostVector(); std::vector thread_temp; - InitThreadTemp(n_threads * kBlockSize, &thread_temp); - std::size_t n_groups = model.learner_model_param->OutputLength(); + InitThreadTemp(n_threads * (blocked ? 
kBlockOfRowsSize : 1), &thread_temp); + bst_idx_t n_groups = model.learner_model_param->OutputLength(); auto out_predt = linalg::MakeTensorView(ctx_, predictions, m->NumRows(), n_groups); - PredictBatchByBlockOfRowsKernel, kBlockSize>( - AdapterView(m.get(), missing), model, tree_begin, tree_end, &thread_temp, - n_threads, out_predt); + + auto launch = [&](auto &&acc) { + auto view = AdapterView{m.get(), missing, acc}; + if (blocked) { + PredictBatchByBlockOfRowsKernel(view, model, tree_begin, tree_end, + &thread_temp, n_threads, out_predt); + } else { + PredictBatchByBlockOfRowsKernel<1>(view, model, tree_begin, tree_end, &thread_temp, + n_threads, out_predt); + } + }; + + if constexpr (std::is_same_v) { + // Make specialization for DataFrame where we need encoding. + if (model.Cats()->HasCategorical()) { + auto [acc, mapping] = MakeCatAccessor(ctx_, m->Cats(), model); + return launch(acc); + } + } + launch(NoOpAccessor{}); } - bool InplacePredict(std::shared_ptr p_m, const gbm::GBTreeModel &model, float missing, + bool InplacePredict(std::shared_ptr p_m, gbm::GBTreeModel const &model, float missing, PredictionCacheEntry *out_preds, bst_tree_t tree_begin, bst_tree_t tree_end) const override { auto proxy = dynamic_cast(p_m.get()); - CHECK(proxy)<< error::InplacePredictProxy(); + CHECK(proxy) << error::InplacePredictProxy(); CHECK(!p_m->Info().IsColumnSplit()) << "Inplace predict support for column-wise data split is not yet implemented."; - auto x = proxy->Adapter(); + auto const &x = proxy->Adapter(); + if (x.type() == typeid(std::shared_ptr)) { - this->DispatchedInplacePredict( - x, p_m, model, missing, out_preds, tree_begin, tree_end); + this->DispatchedInplacePredict(x, p_m, model, missing, out_preds, + tree_begin, tree_end); } else if (x.type() == typeid(std::shared_ptr)) { - this->DispatchedInplacePredict(x, p_m, model, missing, out_preds, - tree_begin, tree_end); + this->DispatchedInplacePredict(x, p_m, model, missing, out_preds, + tree_begin, tree_end); } else if (x.type() == typeid(std::shared_ptr)) { - this->DispatchedInplacePredict( - x, p_m, model, missing, out_preds, tree_begin, tree_end); + this->DispatchedInplacePredict(x, p_m, model, missing, out_preds, + tree_begin, tree_end); } else if (x.type() == typeid(std::shared_ptr)) { - this->DispatchedInplacePredict(x, p_m, model, missing, out_preds, - tree_begin, tree_end); + this->DispatchedInplacePredict(x, p_m, model, missing, out_preds, + tree_begin, tree_end); } else if (x.type() == typeid(std::shared_ptr)) { - this->DispatchedInplacePredict( - x, p_m, model, missing, out_preds, tree_begin, tree_end); + this->DispatchedInplacePredict(x, p_m, model, missing, out_preds, + tree_begin, tree_end); } else { return false; } @@ -834,27 +899,29 @@ class CPUPredictor : public Predictor { if (p_fmat->Info().IsColumnSplit()) { CHECK(!model.learner_model_param->IsVectorLeaf()) << "Predict leaf with column split" << MTNotImplemented(); - + CHECK(!model.Cats()->HasCategorical()) + << "Categorical feature is not yet supported with column-split."; ColumnSplitHelper helper(n_threads, model, 0, ntree_limit); helper.PredictLeaf(ctx_, p_fmat, &preds); return; } std::vector feat_vecs; - const int num_feature = model.learner_model_param->num_feature; + const int n_features = model.learner_model_param->num_feature; InitThreadTemp(n_threads, &feat_vecs); - // start collecting the prediction - for (const auto &batch : p_fmat->GetBatches()) { - // parallel over local batch - auto page = batch.GetView(); + + auto launch = [&](SparsePage const &page, 
auto &&acc) { + using Enc = std::remove_reference_t; // The encoder. common::ParallelFor(page.Size(), n_threads, [&](auto i) { - const int tid = omp_get_thread_num(); - auto ridx = static_cast(batch.base_rowid + i); + auto tid = omp_get_thread_num(); + auto ridx = static_cast(page.base_rowid + i); RegTree::FVec &feats = feat_vecs[tid]; if (feats.Size() == 0) { - feats.Init(num_feature); + feats.Init(n_features); } - feats.Fill(page[i]); + SparsePageView view{&page, std::forward(acc)}; + view.Fill(i, &feats); + for (bst_tree_t j = 0; j < ntree_limit; ++j) { auto const &tree = *model.trees[j]; auto const &cats = tree.GetCategoriesMatrix(); @@ -868,6 +935,17 @@ class CPUPredictor : public Predictor { } feats.Drop(); }); + }; + + // Start collecting the prediction + for (const auto &batch : p_fmat->GetBatches()) { + // parallel over local batch + if (model.Cats()->HasCategorical() && !p_fmat->Cats()->Empty()) { + auto [acc, mapping] = MakeCatAccessor(ctx_, p_fmat->Cats()->HostView(), model); + launch(batch, std::move(acc)); + } else { + launch(batch, NoOpAccessor{}); + } } } @@ -897,20 +975,30 @@ class CPUPredictor : public Predictor { common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) { FillNodeMeanValues(model.trees[i].get(), &(mean_values[i])); }); - // start collecting the contributions - if (!p_fmat->PageExists()) { - auto ft = p_fmat->Info().feature_types.ConstHostVector(); - for (const auto &batch : p_fmat->GetBatches(ctx_, {})) { - PredictContributionKernel(GHistIndexMatrixView{batch, ft}, info, model, tree_weights, - &mean_values, &feat_vecs, &contribs, ntree_limit, approximate, - condition, condition_feature); + + auto launch = [&](auto &&acc) { + // Start collecting the contributions + using Enc = std::remove_reference_t; // The encoder. 
+ if (!p_fmat->PageExists()) { + auto ft = p_fmat->Info().feature_types.ConstHostVector(); + for (const auto &batch : p_fmat->GetBatches(ctx_, {})) { + PredictContributionKernel(GHistIndexMatrixView{batch, std::forward(acc), ft}, info, + model, tree_weights, &mean_values, &feat_vecs, &contribs, + ntree_limit, approximate, condition, condition_feature); + } + } else { + for (const auto &batch : p_fmat->GetBatches()) { + PredictContributionKernel(SparsePageView{&batch, std::forward(acc)}, info, model, + tree_weights, &mean_values, &feat_vecs, &contribs, ntree_limit, + approximate, condition, condition_feature); + } } + }; + if (model.Cats()->HasCategorical() && !p_fmat->CatsShared()->Empty()) { + auto [acc, mapping] = MakeCatAccessor(ctx_, p_fmat->Cats()->HostView(), model); + launch(acc); } else { - for (const auto &batch : p_fmat->GetBatches()) { - PredictContributionKernel( - SparsePageView{&batch}, info, model, tree_weights, &mean_values, &feat_vecs, - &contribs, ntree_limit, approximate, condition, condition_feature); - } + launch(NoOpAccessor{}); } } @@ -923,8 +1011,8 @@ class CPUPredictor : public Predictor { CHECK(!p_fmat->Info().IsColumnSplit()) << "Predict interaction contribution support for " "column-wise data split is not yet implemented."; const MetaInfo& info = p_fmat->Info(); - const int ngroup = model.learner_model_param->num_output_group; - size_t const ncolumns = model.learner_model_param->num_feature; + auto const ngroup = model.learner_model_param->num_output_group; + auto const ncolumns = model.learner_model_param->num_feature; const unsigned row_chunk = ngroup * (ncolumns + 1) * (ncolumns + 1); const unsigned mrow_chunk = (ncolumns + 1) * (ncolumns + 1); const unsigned crow_chunk = ngroup * (ncolumns + 1); @@ -951,7 +1039,7 @@ class CPUPredictor : public Predictor { tree_weights, approximate, 1, i); for (size_t j = 0; j < info.num_row_; ++j) { - for (int l = 0; l < ngroup; ++l) { + for (std::remove_const_t l = 0; l < ngroup; ++l) { const unsigned o_offset = j * row_chunk + l * mrow_chunk + i * (ncolumns + 1); const unsigned c_offset = j * crow_chunk + l * (ncolumns + 1); contribs[o_offset + i] = 0; @@ -960,7 +1048,8 @@ class CPUPredictor : public Predictor { if (k == i) { contribs[o_offset + i] += contribs_diag[c_offset + k]; } else { - contribs[o_offset + k] = (contribs_on[c_offset + k] - contribs_off[c_offset + k])/2.0; + contribs[o_offset + k] = + (contribs_on[c_offset + k] - contribs_off[c_offset + k]) / 2.0; contribs[o_offset + i] -= contribs[o_offset + k]; } } @@ -970,7 +1059,7 @@ class CPUPredictor : public Predictor { } private: - static size_t constexpr kBlockOfRowsSize = 64; + static std::size_t constexpr kBlockOfRowsSize = 64; }; XGBOOST_REGISTER_PREDICTOR(CPUPredictor, "cpu_predictor") diff --git a/src/predictor/predict_fn.h b/src/predictor/predict_fn.h index e3be91d5fa3f..1b00add3e827 100644 --- a/src/predictor/predict_fn.h +++ b/src/predictor/predict_fn.h @@ -8,6 +8,7 @@ #include // for vector #include "../common/categorical.h" // for IsCat, Decision +#include "../data/adapter.h" // for COOTuple #include "xgboost/tree_model.h" // for RegTree namespace xgboost::predictor { @@ -64,5 +65,46 @@ inline bst_tree_t GetTreeLimit(std::vector> const &tree } return ntree_limit; } + +/** + * @brief Accessor for obtaining re-coded categories. 
+ */ +struct CatAccessor { + enc::MappingView enc; + + template + [[nodiscard]] XGBOOST_DEVICE T operator()(T fvalue, Fidx f_idx) const { + if (!enc.Empty() && !enc[f_idx].empty()) { + auto f_mapping = enc[f_idx]; + auto cat_idx = common::AsCat(fvalue); + if (cat_idx >= 0 && cat_idx < common::AsCat(f_mapping.size())) { + fvalue = f_mapping.data()[cat_idx]; + } + } + return fvalue; + } + + [[nodiscard]] XGBOOST_DEVICE float operator()(data::COOTuple const &e) const { + return this->operator()(e.value, e.column_idx); + } + + [[nodiscard]] XGBOOST_DEVICE float operator()(Entry const &e) const { + return this->operator()(e.fvalue, e.index); + } +}; + +/** + * @brief No-op accessor used to handle numeric data. + */ +struct NoOpAccessor { + XGBOOST_DEVICE explicit NoOpAccessor(enc::MappingView const &) {} + NoOpAccessor() = default; + template + [[nodiscard]] XGBOOST_DEVICE T operator()(T fvalue, Fidx) const { + return fvalue; + } + [[nodiscard]] XGBOOST_DEVICE float operator()(data::COOTuple const &e) const { return e.value; } + [[nodiscard]] XGBOOST_DEVICE float operator()(Entry const &e) const { return e.fvalue; } +}; } // namespace xgboost::predictor #endif // XGBOOST_PREDICTOR_PREDICT_FN_H_ diff --git a/tests/cpp/data/test_cat_container.cu b/tests/cpp/data/test_cat_container.cu index 860d386464d7..965135abbe16 100644 --- a/tests/cpp/data/test_cat_container.cu +++ b/tests/cpp/data/test_cat_container.cu @@ -3,10 +3,15 @@ */ #include +#include // for bst_cat_t +#include // for Span -#include "../../../src/common/common.h" +#include // for vector + +#include "../../../src/common/common.h" // for safe_cuda +#include "../../../src/common/threading_utils.h" // for ParallelFor #include "../encoder/df_mock.h" -#include "../helpers.h" +#include "../helpers.h" // for MakeCUDACtx #include "test_cat_container.h" namespace xgboost { @@ -30,4 +35,25 @@ TEST(CatContainer, MixedGpu) { auto ctx = MakeCUDACtx(0); auto df = TestCatContainerMixed(&ctx, eq_check); } + +TEST(CatContainer, ThreadSafety) { + auto ctx = MakeCUDACtx(0); + auto df = DfTest::Make(DfTest::MakeStrs("abc", "bcd", "cde", "ab"), DfTest::MakeInts(2, 2, 3, 0)); + auto h_df = df.View(); + auto cats = test_cat_detail::FromDf(&ctx, h_df); + cats.Sort(&ctx); // not thread safe + + common::ParallelFor(ctx.Threads(), 64, [&](auto i) { + auto sorted_idx = cats.RefSortedIndex(&ctx); + if (i % 2 == 0) { + auto h_cats = cats.HostView(); + ASSERT_EQ(h_cats.n_total_cats, 8); + } else { + auto d_cats = cats.DeviceView(&ctx); + ASSERT_EQ(d_cats.n_total_cats, 8); + } + auto sol = std::vector{3, 0, 1, 2, 3, 0, 1, 2}; + eq_check(sorted_idx, sol); + }); +} } // namespace xgboost diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index 340188b23652..0d1a48201ab9 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -34,7 +34,8 @@ class TestGPUUpdatersMulti: ) @settings(deadline=None, max_examples=50, print_blob=True) def test_hist(self, param, num_rounds, dataset): - param["tree_method"] = "gpu_hist" + param["tree_method"] = "hist" + param["device"] = "cuda" param = dataset.set_params(param) result = train_result(param, dataset.get_dmat(), num_rounds) note(str(result)) @@ -208,7 +209,8 @@ def test_categorical_ames_housing( dataset = tm.TestDataset( "ames_housing", tm.data.get_ames_housing, "reg:squarederror", "rmse" ) - cat_parameters["tree_method"] = "gpu_hist" + cat_parameters["tree_method"] = "hist" + cat_parameters["device"] = "cuda" results = train_result(cat_parameters, 
dataset.get_dmat(), 16) tm.non_increasing(results["train"]["rmse"]) @@ -260,7 +262,8 @@ def test_gpu_hist_device_dmatrix( ) -> None: # We cannot handle empty dataset yet assume(len(dataset.y) > 0) - param["tree_method"] = "gpu_hist" + param["tree_method"] = "hist" + param["device"] = "cuda" param = dataset.set_params(param) result = train_result( param, @@ -281,7 +284,8 @@ def test_external_memory(self, param, num_rounds, dataset): return # We cannot handle empty dataset yet assume(len(dataset.y) > 0) - param["tree_method"] = "gpu_hist" + param["tree_method"] = "hist" + param["device"] = "cuda" param = dataset.set_params(param) m = dataset.get_external_dmat() external_result = train_result(param, m, num_rounds) @@ -317,8 +321,10 @@ def test_empty_dmatrix_prediction(self): @pytest.mark.mgpu @given(tm.make_dataset_strategy(), strategies.integers(0, 10)) @settings(deadline=None, max_examples=10, print_blob=True) - def test_specified_gpu_id_gpu_update(self, dataset, gpu_id): - param = {"tree_method": "gpu_hist", "gpu_id": gpu_id} + def test_specified_gpu_id_gpu_update( + self, dataset: tm.TestDataset, gpu_id: int + ) -> None: + param = {"tree_method": "hist", "device": f"cuda:{gpu_id}"} param = dataset.set_params(param) result = train_result(param, dataset.get_dmat(), 10) assert tm.non_increasing(result["train"][dataset.metric]) diff --git a/tests/python-gpu/test_gpu_with_sklearn.py b/tests/python-gpu/test_gpu_with_sklearn.py index a01e79ccc88a..0e5ac2f6d0a7 100644 --- a/tests/python-gpu/test_gpu_with_sklearn.py +++ b/tests/python-gpu/test_gpu_with_sklearn.py @@ -80,13 +80,18 @@ def test_categorical(): from sklearn.datasets import load_svmlight_file data_dir = tm.data_dir(__file__) - X, y = load_svmlight_file(os.path.join(data_dir, "agaricus.txt.train")) + X, y = load_svmlight_file( + os.path.join(data_dir, "agaricus.txt.train"), dtype=np.float32 + ) clf = xgb.XGBClassifier( - tree_method="gpu_hist", + tree_method="hist", + device="cuda", enable_categorical=True, n_estimators=10, ) X = pd.DataFrame(X.todense()).astype("category") + for c in X.columns: + X[c] = X[c].cat.rename_categories(int) clf.fit(X, y) with tempfile.TemporaryDirectory() as tempdir: @@ -105,7 +110,7 @@ def test_categorical(): def check_predt(X, y): reg = xgb.XGBRegressor( - tree_method="gpu_hist", enable_categorical=True, n_estimators=64 + tree_method="hist", enable_categorical=True, n_estimators=64, device="cuda" ) reg.fit(X, y) predts = reg.predict(X) diff --git a/tests/python/test_demos.py b/tests/python/test_demos.py index d20e5bc384cc..9f7bd7123fde 100644 --- a/tests/python/test_demos.py +++ b/tests/python/test_demos.py @@ -228,6 +228,8 @@ def test_cli_regression_demo() -> None: subprocess.check_call(cmd, cwd=reg_dir) exe = os.path.join(DEMO_DIR, os.path.pardir, "xgboost") + if not os.path.exists(exe): + pytest.skip("CLI executable not found.") conf = os.path.join(reg_dir, "machine.conf") subprocess.check_call([exe, conf], cwd=reg_dir) @@ -237,6 +239,9 @@ def test_cli_regression_demo() -> None: ) def test_cli_binary_classification() -> None: cls_dir = os.path.join(CLI_DEMO_DIR, "binary_classification") + exe = os.path.join(DEMO_DIR, os.path.pardir, "xgboost") + if not os.path.exists(exe): + pytest.skip("CLI executable not found.") with tm.DirectoryExcursion(cls_dir, cleanup=True): subprocess.check_call(["./runexp.sh"]) os.remove("0002.model") diff --git a/tests/python/test_ordinal.py b/tests/python/test_ordinal.py index 6863733f2d47..05cd641693a5 100644 --- a/tests/python/test_ordinal.py +++ 
b/tests/python/test_ordinal.py @@ -5,6 +5,11 @@ run_cat_container, run_cat_container_iter, run_cat_container_mixed, + run_cat_invalid, + run_cat_leaf, + run_cat_predict, + run_cat_shap, + run_cat_thread_safety, ) pytestmark = pytest.mark.skipif(**tm.no_multiple(tm.no_arrow(), tm.no_pandas())) @@ -20,3 +25,23 @@ def test_cat_container_mixed() -> None: def test_cat_container_iter() -> None: run_cat_container_iter("cpu") + + +def test_cat_predict() -> None: + run_cat_predict("cpu") + + +def test_cat_invalid() -> None: + run_cat_invalid("cpu") + + +def test_cat_thread_safety() -> None: + run_cat_thread_safety("cpu") + + +def test_cat_shap() -> None: + run_cat_shap("cpu") + + +def test_cat_leaf() -> None: + run_cat_leaf("cpu") diff --git a/tests/python/test_predict.py b/tests/python/test_predict.py index 4a81e807bfa3..fc330962cde9 100644 --- a/tests/python/test_predict.py +++ b/tests/python/test_predict.py @@ -1,6 +1,7 @@ """Tests for running inplace prediction.""" from concurrent.futures import ThreadPoolExecutor +from typing import List, Union import numpy as np import pandas as pd @@ -251,11 +252,14 @@ def test_dtypes(self) -> None: @pytest.mark.skipif(**tm.no_pandas()) def test_pd_dtypes(self) -> None: + import pandas as pd from pandas.api.types import is_bool_dtype for orig, x in pd_dtypes(): - dtypes = orig.dtypes if isinstance(orig, pd.DataFrame) else [orig.dtypes] - if isinstance(orig, pd.DataFrame) and is_bool_dtype(dtypes[0]): + dtypes: Union[List, pd.Series] = ( + orig.dtypes if isinstance(orig, pd.DataFrame) else [orig.dtypes] + ) + if isinstance(orig, pd.DataFrame) and is_bool_dtype(dtypes.iloc[0]): continue y = np.arange(x.shape[0]) Xy = xgb.DMatrix(orig, y, enable_categorical=True) diff --git a/tests/test_distributed/test_with_dask/test_with_dask.py b/tests/test_distributed/test_with_dask/test_with_dask.py index 90a84a0090c1..0eccf1f46c67 100644 --- a/tests/test_distributed/test_with_dask/test_with_dask.py +++ b/tests/test_distributed/test_with_dask/test_with_dask.py @@ -1,4 +1,4 @@ -"""Copyright 2019-2024, XGBoost contributors""" +"""Copyright 2019-2025, XGBoost contributors""" import asyncio import json From 18c91fa62078d5b8d48a3f23748412c7e45a4df8 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 15 Mar 2025 13:56:21 +0800 Subject: [PATCH 006/224] Update release script for sdist. 
(#11337) (#11340) --- .../script/release_artifacts.py | 52 +++++++++---------- 1 file changed, 24 insertions(+), 28 deletions(-) rename dev/release-artifacts.py => ops/script/release_artifacts.py (92%) diff --git a/dev/release-artifacts.py b/ops/script/release_artifacts.py similarity index 92% rename from dev/release-artifacts.py rename to ops/script/release_artifacts.py index fc6c0f3b1307..52963fac8ded 100644 --- a/dev/release-artifacts.py +++ b/ops/script/release_artifacts.py @@ -5,39 +5,30 @@ """ import argparse -import os import shutil import subprocess import tarfile import tempfile from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple from urllib.request import urlretrieve import tqdm from packaging import version +from pypi_variants import make_pyproject from sh.contrib import git +from test_utils import PY_PACKAGE +from test_utils import ROOT as root_path +from test_utils import DirectoryExcursion # S3 bucket hosting the release artifacts S3_BUCKET_URL = "https://s3-us-west-2.amazonaws.com/xgboost-nightly-builds" -ROOT = Path(__file__).absolute().parent.parent -DIST = ROOT / "python-package" / "dist" +DIST = Path(PY_PACKAGE) / "dist" +ROOT = Path(root_path) pbar = None -class DirectoryExcursion: - def __init__(self, path: Path) -> None: - self.path = path - self.curdir = Path.cwd().resolve() - - def __enter__(self) -> None: - os.chdir(self.path) - - def __exit__(self, *args: Any) -> None: - os.chdir(self.curdir) - - def show_progress(block_num: int, block_size: int, total_size: int) -> None: """Show file download progress.""" global pbar @@ -118,16 +109,24 @@ def make_python_sdist( dist_dir = outdir / "dist" dist_dir.mkdir(exist_ok=True) - # Apply patch to remove NCCL dependency - # Save the original content of pyproject.toml so that we can restore it later + # Build sdist for `xgboost-cpu`. with DirectoryExcursion(ROOT): - with open("python-package/pyproject.toml", "r") as f: - orig_pyproj_lines = f.read() - with open("ops/patch/remove_nccl_dep.patch", "r") as f: - patch_lines = f.read() - subprocess.run( - ["patch", "-p0"], input=patch_lines, check=True, text=True, encoding="utf-8" + make_pyproject("cpu") + with DirectoryExcursion(ROOT / "python-package"): + subprocess.run(["python", "-m", "build", "--sdist"], check=True) + sdist_name = ( + f"xgboost_cpu-{release}{rc}{rc_ver}.tar.gz" + if rc + else f"xgboost_cpu-{release}.tar.gz" ) + src = DIST / sdist_name + subprocess.run(["twine", "check", str(src)], check=True) + dest = dist_dir / sdist_name + shutil.move(src, dest) + + # Build sdist for `xgboost`. 
+ with DirectoryExcursion(ROOT): + make_pyproject("default") with DirectoryExcursion(ROOT / "python-package"): subprocess.run(["python", "-m", "build", "--sdist"], check=True) @@ -141,10 +140,6 @@ def make_python_sdist( dest = dist_dir / sdist_name shutil.move(src, dest) - with DirectoryExcursion(ROOT): - with open("python-package/pyproject.toml", "w") as f: - f.write(orig_pyproj_lines) - def download_python_wheels(branch: str, commit_hash: str, outdir: Path) -> None: """Download all Python binary wheels for the specified branch.""" @@ -318,6 +313,7 @@ def main(args: argparse.Namespace) -> None: rc_ver: Optional[int] = None else: # RC release + assert release_parsed.pre is not None rc, rc_ver = release_parsed.pre if rc != "rc": raise ValueError( From 429f81279ce2ed6d79c23f884fbf6cb9c926768f Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 15 Mar 2025 18:55:37 +0800 Subject: [PATCH 007/224] 3.0 release note. (#11285) --- doc/changes/index.rst | 1 + doc/changes/v3.0.0.rst | 368 ++++++++++++++++++++++++++++++++++++++++ doc/conf.py | 7 +- doc/requirements.txt | 1 + ops/script/changelog.py | 32 ++++ 5 files changed, 408 insertions(+), 1 deletion(-) create mode 100644 doc/changes/v3.0.0.rst create mode 100644 ops/script/changelog.py diff --git a/doc/changes/index.rst b/doc/changes/index.rst index 09bc215075e4..c1e155ca0421 100644 --- a/doc/changes/index.rst +++ b/doc/changes/index.rst @@ -8,4 +8,5 @@ For release notes prior to the 2.1 release, please see `news ` for more info. +- Optimization for nearly-dense input, see the section for :ref:`optimization + <3_0_optimization>` for more info. + +See our latest document for details :doc:`/tutorials/external_memory`. The PyPI package +(``pip install``) doesn't have ``RMM`` support, which is required by the GPU external +memory implementation. To experiment, you can compile XGBoost from source or wait for the +RAPIDS conda package to be available. + +.. _3_0_networking: + +********** +Networking +********** + +Continuing the work from the previous release, we updated the network module to improve +reliability. (:pr:`10453`, :pr:`10756`, :pr:`11111`, :pr:`10914`, :pr:`10828`, :pr:`10735`, :pr:`10693`, :pr:`10676`, :pr:`10349`, +:pr:`10397`, :pr:`10566`, :pr:`10526`, :pr:`10349`) + +The timeout option is now supported for NCCL using the NCCL asynchronous mode (:pr:`10850`, +:pr:`10934`, :pr:`10945`, :pr:`10930`). + +In addition, a new :py:class:`~xgboost.collective.Config` class is added for users to +specify various options including timeout, tracker port, etc for distributed +training. Both the Dask interface and the PySpark interface support the new +configuration. (:pr:`11003`, :pr:`10281`, :pr:`10983`, :pr:`10973`) + +**** +SYCL +**** + +Continuing the work on the SYCL integration, there are significant improvements in the +feature coverage for this release from more training parameters and more objectives to +distributed training, along with various optimization (:pr:`10884`, :pr:`10883`). + +Starting with 3.0, the SYCL-plugin is close to feature-complete, users can start working +on SYCL devices for in-core training and inference. Newly introduced features include: + +- Dask support for distributed training (:pr:`10812`) + +- Various training procedures, including split evaluation (:pr:`10605`, :pr:`10636`), grow policy + (:pr:`10690`, :pr:`10681`), cached prediction (:pr:`10701`). + +- Updates for objective functions. 
(:pr:`11029`, :pr:`10931`, :pr:`11016`, :pr:`10993`, :pr:`11064`, :pr:`10325`) + +- On-going work for float32-only devices. (:pr:`10702`) + +Other related PRs (:pr:`10842`, :pr:`10543`, :pr:`10806`, :pr:`10943`, :pr:`10987`, :pr:`10548`, :pr:`10922`, :pr:`10898`, :pr:`10576`) + +.. _3_0_features: + +******** +Features +******** + +This section describes new features in the XGBoost core. For language-specific features, +please visit corresponding sections. + +- A new initialization method for objectives that are derived from GLM. The new method is + based on the mean value of the input labels. The new method changes the result of the + estimated ``base_score``. (:pr:`10298`, :pr:`11331`) + +- The :py:class:`xgboost.QuantileDMatrix` can be used with all prediction types for both + CPU and GPU. + +- In prior releases, XGBoost makes a copy for the booster to release memory held by + internal tree methods. We formalize the procedure into a new booster method + :py:meth:`~xgboost.Booster.reset` / :cpp:func:`XGBoosterReset`. (:pr:`11042`) + +- OpenMP thread setting is exposed to the XGBoost global configuration. Users can use it + to workaround hardcoded OpenMP environment variables. (:pr:`11175`) + +- We improved learning to rank tasks for better hyper-parameter configuration and for + distributed training. + + + In 3.0, all three distributed interfaces, including Dask, Spark, and PySpark, support + sorting the data based on query ID. The option for the + :py:class:`~xgboost.dask.DaskXGBRanker` is true by default and can be opted + out. (:pr:`11146`, :pr:`11007`, :pr:`11047`, :pr:`11012`, :pr:`10823`, :pr:`11023`) + + + Also for learning to rank, a new parameter ``lambdarank_score_normalization`` is + introduced to make one of the normalizations optional. (:pr:`11272`) + + + The ``lambdarank_normalization`` now uses the number of pairs when normalizing the + ``mean`` pair strategy. Previously, the gradient was used for both ``topk`` and + ``mean``. :pr:`11322` + +- We have improved GPU quantile sketching to reduce memory usage. The improvement helps + the construction of the :py:class:`~xgboost.QuantileDMatrix` and the new + :py:class:`~xgboost.ExtMemQuantileDMatrix`. + + + A new multi-level sketching algorithm is employed to reduce the overall memory usage + with batched inputs. + + In addition to algorithmic changes, internal memory usage estimation and the quantile + container is also updated. (:pr:`10761`, :pr:`10843`) + + The change introduces two more parameters for the :py:class:`~xgboost.QuantileDMatrix` + and :py:class:`~xgboost.DataIter`, namely, ``max_quantile_batches`` and + ``min_cache_page_bytes``. + +- More work is needed to improve the support of categorical features. This release + supports plotting trees with stat for categorical nodes (:pr:`11053`). In addition, some + preparation work is ongoing for auto re-coding categories. (:pr:`11094`, :pr:`11114`, + :pr:`11089`) These are feature enhancements instead of blocking issues. +- Implement weight-based feature importance for vector-leaf. (:pr:`10700`) +- Reduced logging in the DMatrix construction. (:pr:`11080`) + +.. _3_0_optimization: + +************ +Optimization +************ + +In addition to the external memory and quantile sketching improvements, we have a number +of optimizations and performance fixes. + +- GPU tree methods now use significantly less memory for both dense inputs and near-dense + inputs. 
(:pr:`10821`, :pr:`10870`) +- For near-dense inputs, GPU training is much faster for both ``hist`` (about 2x) and + ``approx``. +- Quantile regression on CPU now can handle imbalance trees much more efficiently. (:pr:`11275`) +- Small optimization for DMatrix construction to reduce latency. Also, C users can now + reuse the :cpp:func:`ProxyDMatrix ` for multiple inference + calls. (:pr:`11273`) +- CPU prediction performance for :py:class:`~xgboost.QuantileDMatrix` has been improved + (:pr:`11139`) and now is on par with normal ``DMatrix``. +- Fixed a performance issue for running inference using CPU with extremely sparse + :py:class:`~xgboost.QuantileDMatrix` (:pr:`11250`). +- Optimize CPU training memory allocation for improved performance. (:pr:`11112`) +- Improved RMM (rapids memory manager) integration. Now, with the help of + :py:func:`~xgboost.config_context`, all memory allocated by XGBoost should be routed to + RMM. As a bonus, all ``thrust`` algorithms now use async policy. (:pr:`10873`, :pr:`11173`, :pr:`10712`, + :pr:`10712`, :pr:`10562`) +- When used without RMM, XGBoost is more careful with its use of caching allocator to + avoid holding too much device memory. (:pr:`10582`) + +**************** +Breaking Changes +**************** +This section lists breaking changes that affect all packages. + +- Remove the deprecated ``DeviceQuantileDMatrix``. (:pr:`10974`, :pr:`10491`) +- Support for saving the model in the ``deprecated`` has been removed. Users can still + load old models in 3.0. (:pr:`10490`) +- Support for the legacy (blocking) CUDA stream is removed (:pr:`10607`) + +********* +Bug Fixes +********* +- Fix the quantile error metric (pinball loss) with multiple quantiles. (:pr:`11279`) +- Fix potential access error when running prediction in multi-thread environment. (:pr:`11167`) +- Check the correct dump format for the ``gblinear``. (:pr:`10831`) + +************* +Documentation +************* +- A new tutorial for advanced usage with custom objective functions. (:pr:`10283`, :pr:`10725`) +- The new online document site now shows documents for all packages including Python, R, + and JVM-based packages. (:pr:`11240`, :pr:`11216`, :pr:`11166`) +- Lots of enhancements. (:pr:`10822`, 11137, :pr:`11138`, :pr:`11246`, :pr:`11266`, :pr:`11253`, :pr:`10731`, :pr:`11222`, + :pr:`10551`, :pr:`10533`) +- Consistent use of cmake in documents. (:pr:`10717`) +- Add a brief description for using the ``offset`` from the GLM setting (like + ``Poisson``). (:pr:`10996`) +- Cleanup document for building from source. (:pr:`11145`) +- Various fixes. (:pr:`10412`, :pr:`10405`, :pr:`10353`, :pr:`10464`, :pr:`10587`, :pr:`10350`, :pr:`11131`, :pr:`10815`) +- Maintenance. (:pr:`11052`, :pr:`10380`) + +************** +Python Package +************** + +- The ``feature_weights`` parameter in the sklearn interface is now defined as + a scikit-learn parameter. (:pr:`9506`) +- Initial support for polars, categorical feature is not yet supported. (:pr:`11126`, :pr:`11172`, + :pr:`11116`) +- Reduce pandas dataframe overhead and overhead for various imports. (:pr:`11058`, :pr:`11068`) +- Better xlabel in :py:func:`~xgboost.plot_importance` (:pr:`11009`) +- Validate reference dataset for training. The :py:func:`~xgboost.train` function now + throws an error if a :py:class:`~xgboost.QuantileDMatrix` is used as a validation + dataset without a reference. (:pr:`11105`) +- Fix misleading errors when feature names are missing during inference (:pr:`10814`) +- Add Stacklevel to Python warning callback. 
The change helps improve the error message + for the Python package. (:pr:`10977`) +- Remove circular reference in DataIter. It helps reduce memory usage. (:pr:`11177`) +- Add checks for invalid inputs for `cv`. (:pr:`11255`) +- Update Python project classifiers. (:pr:`10381`, :pr:`11028`) +- Support doc link for the sklearn module. Users can now find links to documents in a + jupyter notebook. (:pr:`10287`) + +- Dask + + + Prevent the training from hanging due to aborted workers. (:pr:`10985`) This helps + Dask XGBoost be robust against error. When a worker is killed, the training will fail + with an exception instead of hang. + + Optional support for client-side logging. (:pr:`10942`) + + Fix LTR with empty partition and NCCL error. (:pr:`11152`) + + Update to work with the latest Dask. (:pr:`11291`) + + See the :ref:`3_0_features` section for changes to ranking models. + + See the :ref:`3_0_networking` section for changes with the communication module. + +- PySpark + + + Expose Training and Validation Metrics. (:pr:`11133`) + + Add barrier before initializing the communicator. (:pr:`10938`) + + Extend support for columnar input to CPU (GPU-only previously). (:pr:`11299`) + + See the :ref:`3_0_features` section for changes to ranking models. + + See the :ref:`3_0_networking` section for changes with the communication module. + +- Document updates (:pr:`11265`). +- Maintenance. (:pr:`11071`, :pr:`11211`, :pr:`10837`, :pr:`10754`, :pr:`10347`, :pr:`10678`, :pr:`11002`, :pr:`10692`, :pr:`11006`, + :pr:`10972`, :pr:`10907`, :pr:`10659`, :pr:`10358`, :pr:`11149`, :pr:`11178`, :pr:`11248`) + +- Breaking changes + + + Remove deprecated `feval`. (:pr:`11051`) + + Remove dask from the default import. (:pr:`10935`) Users are now required to import the + XGBoost Dask through: + + .. code-block:: python + + from xgboost import dask as dxgb + + instead of: + + .. code-block:: python + + import xgboost as xgb + xgb.dask + + The change helps avoid introducing dask into the default import set. + + + Bump Python requirement to 3.10. (:pr:`10434`) + + Drop support for datatable. (:pr:`11070`) + +********* +R Package +********* + +We have been reworking the R package for a few releases now. In 3.0, we will start +publishing a new R package on public repositories, likely R-universe, before moving toward +a CRAN update. The new package features a much more ergonomic interface, which is also +more idiomatic to R speakers. In addition, a range of new features are introduced to the +package. To name a few, the new package includes categorical feature support, +``QuantileDMatrix``, and an initial implementation of the external memory training. + +Also, we finally have an online documentation site for the R package featuring both +vignettes and API references (:pr:`11166`, :pr:`11257`). A good starting point for the new interface +is the new ``xgboost()`` function. We won't list all the feature gains here, as there are +too many! Please visit the :doc:`/R-package/index` for more info. There's a migration +guide (:pr:`11197`) there if you use a previous XGBoost R package version. + +- Support for the MSVC build was dropped due to incompatibility with R headers. (:pr:`10355`, + :pr:`11150`) +- Maintenance (:pr:`11259`) +- Related PRs. 
(:pr:`11171`, :pr:`11231`, :pr:`11223`, :pr:`11073`, :pr:`11224`, :pr:`11076`, :pr:`11084`, :pr:`11081`, + :pr:`11072`, :pr:`11170`, :pr:`11123`, :pr:`11168`, :pr:`11264`, :pr:`11140`, :pr:`11117`, :pr:`11104`, :pr:`11095`, :pr:`11125`, :pr:`11124`, + :pr:`11122`, :pr:`11108`, :pr:`11102`, :pr:`11101`, :pr:`11100`, :pr:`11077`, :pr:`11099`, :pr:`11074`, :pr:`11065`, :pr:`11092`, :pr:`11090`, + :pr:`11096`, :pr:`11148`, :pr:`11151`, :pr:`11159`, :pr:`11204`, :pr:`11254`, :pr:`11109`, :pr:`11141`, :pr:`10798`, :pr:`10743`, :pr:`10849`, + :pr:`10747`, :pr:`11022`, :pr:`10989`, :pr:`11026`, :pr:`11060`, :pr:`11059`, :pr:`11041`, :pr:`11043`, :pr:`11025`, :pr:`10674`, :pr:`10727`, + :pr:`10745`, :pr:`10733`, :pr:`10750`, :pr:`10749`, :pr:`10744`, :pr:`10794`, :pr:`10330`, :pr:`10698`, :pr:`10687`, :pr:`10688`, :pr:`10654`, + :pr:`10456`, :pr:`10556`, :pr:`10465`, :pr:`10337`) + +************ +JVM Packages +************ + +The XGBoost 3.0 release features a significant update to the JVM packages, and in +particular, the Spark package. There are breaking changes in packaging and some +parameters. Please visit the :doc:`migration guide ` for +related changes. The work brings new features and a more unified feature set between CPU +and GPU implementation. (:pr:`10639`, :pr:`10833`, :pr:`10845`, :pr:`10847`, :pr:`10635`, :pr:`10630`, :pr:`11179`, :pr:`11184`) + +- Automatic partitioning for distributed learning to rank. See the :ref:`features + <3_0_features>` section above (:pr:`11023`). +- Resolve spark compatibility issue (:pr:`10917`) +- Support missing value when constructing dmatrix with iterator (:pr:`10628`) +- Fix transform performance issue (:pr:`10925`) +- Honor skip.native.build option in xgboost4j-gpu (:pr:`10496`) +- Support array features type for CPU (:pr:`10937`) +- Change default missing value to ``NaN`` for better alignment (:pr:`11225`) +- Don't cast to float if it's already float (:pr:`10386`) +- Maintenance. (:pr:`10982`, :pr:`10979`, :pr:`10978`, :pr:`10673`, :pr:`10660`, :pr:`10835`, :pr:`10836`, :pr:`10857`, :pr:`10618`, + :pr:`10627`) + +*********** +Maintenance +*********** + +Code maintenance includes both refactoring (:pr:`10531`, :pr:`10573`, :pr:`11069`), cleanups (:pr:`11129`, +:pr:`10878`, :pr:`11244`, :pr:`10401`, :pr:`10502`, :pr:`11107`, :pr:`11097`, :pr:`11130`, :pr:`10758`, :pr:`10923`, :pr:`10541`, :pr:`10990`), +and improvements for tests (:pr:`10611`, :pr:`10658`, :pr:`10583`, :pr:`11245`, :pr:`10708`), along with fixing +various warnings in compilers and test dependencies (:pr:`10757`, :pr:`10641`, :pr:`11062`, +:pr:`11226`). Also, miscellaneous updates, including some dev scripts and profiling annotations +(:pr:`10485`, :pr:`10657`, :pr:`10854`, :pr:`10718`, :pr:`11158`, :pr:`10697`, :pr:`11276`). 
+ +Lastly, dependency updates (:pr:`10362`, :pr:`10363`, :pr:`10360`, :pr:`10373`, :pr:`10377`, :pr:`10368`, :pr:`10369`, +:pr:`10366`, :pr:`11032`, :pr:`11037`, :pr:`11036`, :pr:`11035`, :pr:`11034`, :pr:`10518`, :pr:`10536`, :pr:`10586`, :pr:`10585`, :pr:`10458`, +:pr:`10547`, :pr:`10429`, :pr:`10517`, :pr:`10497`, :pr:`10588`, :pr:`10975`, :pr:`10971`, :pr:`10970`, :pr:`10949`, :pr:`10947`, :pr:`10863`, +:pr:`10953`, :pr:`10954`, :pr:`10951`, :pr:`10590`, :pr:`10600`, :pr:`10599`, :pr:`10535`, :pr:`10516`, :pr:`10786`, :pr:`10859`, :pr:`10785`, +:pr:`10779`, :pr:`10790`, :pr:`10777`, :pr:`10855`, :pr:`10848`, :pr:`10778`, :pr:`10772`, :pr:`10771`, :pr:`10862`, :pr:`10952`, :pr:`10768`, +:pr:`10770`, :pr:`10769`, :pr:`10664`, :pr:`10663`, :pr:`10892`, :pr:`10979`, :pr:`10978`). + +*** +CI +*** + +- The CI is reworked to use `RunsOn` to integrate custom CI pipelines with GitHub + action. The migration helps us reduce the maintenance burden and make the CI + configuration more accessible to others. (:pr:`11001`, :pr:`11079`, :pr:`10649`, :pr:`11196`, :pr:`11055`, + :pr:`10483`, :pr:`11078`, :pr:`11157`) + +- Other maintenance work includes various small fixes, enhancements, and tooling + updates. (:pr:`10877`, :pr:`10494`, :pr:`10351`, :pr:`10609`, :pr:`11192`, :pr:`11188`, :pr:`11142`, :pr:`10730`, :pr:`11066`, + :pr:`11063`, :pr:`10800`, :pr:`10995`, :pr:`10858`, :pr:`10685`, :pr:`10593`, :pr:`11061`) diff --git a/doc/conf.py b/doc/conf.py index ce6a0219ccb1..6c5c456ac9f8 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -225,7 +225,7 @@ def is_readthedocs_build(): # General information about the project. project = "xgboost" author = "%s developers" % project -copyright = "2022, %s" % author +copyright = "2025, %s" % author github_doc_root = "https://github.com/dmlc/xgboost/tree/master/doc/" # Add any Sphinx extension module names here, as strings. They can be @@ -238,6 +238,7 @@ def is_readthedocs_build(): "sphinx.ext.mathjax", "sphinx.ext.intersphinx", "sphinx_gallery.gen_gallery", + "sphinx_issues", "breathe", "myst_parser", ] @@ -262,6 +263,10 @@ def is_readthedocs_build(): "matplotlib_animations": True, } +# Sphinx-issues configuration +# Path to GitHub repo {group}/{project} (note that `group` is the GitHub user or organization) +issues_github_path = "dmlc/xgboost" + autodoc_typehints = "description" graphviz_output_format = "png" diff --git a/doc/requirements.txt b/doc/requirements.txt index 9a2097035228..d73e5bdf2b84 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -11,6 +11,7 @@ scipy myst-parser ray[train] sphinx-gallery +sphinx-issues dask pyspark cloudpickle diff --git a/ops/script/changelog.py b/ops/script/changelog.py new file mode 100644 index 000000000000..552a82f2e49d --- /dev/null +++ b/ops/script/changelog.py @@ -0,0 +1,32 @@ +"""Helper script for creating links to PRs for changelog. This should be used with the +`sphinx-issues` extension. + +""" + +import argparse +import os +import re + +from test_utils import ROOT + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--version", + type=str, + required=True, + help="Major version of the changelog, e.g., 3.0.0 .", + ) + args = parser.parse_args() + version = args.version + + fname = os.path.join(ROOT, f"doc/changes/v{version}.rst") + + with open(fname) as fd: + note = fd.read() + + # E.g. #11285 -> :pr:`11285`. 
+ regex = re.compile(r"(#)(\d+)") + note = re.sub(regex, r":pr:`\2`", note) + with open(fname, "w") as fd: + fd.write(note) From 257b87ca949b3d622886e5cdbf23c66eaa9784d4 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 17 Mar 2025 13:54:21 +0800 Subject: [PATCH 008/224] [EM] Fix page concatenation for validation dataset. (#11338) --- src/data/ellpack_page_source.cu | 37 ++++++----- src/data/ellpack_page_source.h | 7 ++- .../cpp/data/test_ellpack_page_raw_format.cu | 61 +++++++++++++++++++ 3 files changed, 90 insertions(+), 15 deletions(-) diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index 1b839e89df15..8dbf2d3ec696 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -104,23 +104,29 @@ class EllpackHostCacheStreamImpl { this->cache_->sizes_orig.push_back(page.Impl()->MemCostBytes()); auto orig_ptr = this->cache_->sizes_orig.size() - 1; + CHECK_EQ(this->cache_->pages.size(), this->cache_->on_device.size()); CHECK_LT(orig_ptr, this->cache_->NumBatchesOrig()); auto cache_idx = this->cache_->cache_mapping.at(orig_ptr); // Wrap up the previous page if this is a new page, or this is the last page. auto new_page = cache_idx == this->cache_->pages.size(); - + // Last page expected from the user. auto last_page = (orig_ptr + 1) == this->cache_->NumBatchesOrig(); - // No page concatenation is performed. If there's page concatenation, then the number - // of pages in the cache must be smaller than the input number of pages. - bool no_concat = this->cache_->NumBatchesOrig() == this->cache_->buffer_rows.size(); + + bool const no_concat = this->cache_->NoConcat(); + // Whether the page should be cached in device. If true, then we don't need to make a // copy during write since the temporary page is already in device when page // concatenation is enabled. - bool to_device = this->cache_->prefer_device && - this->cache_->NumDevicePages() < this->cache_->max_num_device_pages; - - auto commit_page = [&ctx](EllpackPageImpl const* old_impl) { + // + // This applies only to a new cached page. If we are concatenating this page to an + // existing cached page, then we should respect the existing flag obtained from the + // first page of the cached page. + bool to_device_if_new_page = + this->cache_->prefer_device && + this->cache_->NumDevicePages() < this->cache_->max_num_device_pages; + + auto commit_host_page = [](EllpackPageImpl const* old_impl) { CHECK_EQ(old_impl->gidx_buffer.Resource()->Type(), common::ResourceHandler::kCudaMalloc); auto new_impl = std::make_unique(); new_impl->CopyInfo(old_impl); @@ -137,7 +143,7 @@ class EllpackHostCacheStreamImpl { auto new_impl = std::make_unique(); new_impl->CopyInfo(page.Impl()); - if (to_device) { + if (to_device_if_new_page) { // Copy to device new_impl->gidx_buffer = common::MakeFixedVecWithCudaMalloc( page.Impl()->gidx_buffer.size()); @@ -151,15 +157,16 @@ class EllpackHostCacheStreamImpl { this->cache_->offsets.push_back(new_impl->n_rows * new_impl->info.row_stride); this->cache_->pages.push_back(std::move(new_impl)); + this->cache_->on_device.push_back(to_device_if_new_page); return new_page; } if (new_page) { // No need to copy if it's already in device. - if (!this->cache_->pages.empty() && !to_device) { + if (!this->cache_->pages.empty() && !this->cache_->on_device.back()) { // Need to wrap up the previous page. - auto commited = commit_page(this->cache_->pages.back().get()); - // Replace the previous page with a new page. 
+ auto commited = commit_host_page(this->cache_->pages.back().get()); + // Replace the previous page (on device) with a new page on host. this->cache_->pages.back() = std::move(commited); } // Push a new page @@ -174,7 +181,9 @@ class EllpackHostCacheStreamImpl { auto offset = new_impl->Copy(&ctx, impl, 0); this->cache_->offsets.push_back(offset); + this->cache_->pages.push_back(std::move(new_impl)); + this->cache_->on_device.push_back(to_device_if_new_page); } else { CHECK(!this->cache_->pages.empty()); CHECK_EQ(cache_idx, this->cache_->pages.size() - 1); @@ -182,8 +191,8 @@ class EllpackHostCacheStreamImpl { auto offset = new_impl->Copy(&ctx, impl, this->cache_->offsets.back()); this->cache_->offsets.back() += offset; // No need to copy if it's already in device. - if (last_page && !to_device) { - auto commited = commit_page(this->cache_->pages.back().get()); + if (last_page && !this->cache_->on_device.back()) { + auto commited = commit_host_page(this->cache_->pages.back().get()); this->cache_->pages.back() = std::move(commited); } } diff --git a/src/data/ellpack_page_source.h b/src/data/ellpack_page_source.h index cb921daa446f..a668c39bdef4 100644 --- a/src/data/ellpack_page_source.h +++ b/src/data/ellpack_page_source.h @@ -1,5 +1,5 @@ /** - * Copyright 2019-2024, XGBoost Contributors + * Copyright 2019-2025, XGBoost Contributors */ #ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_ @@ -47,6 +47,7 @@ struct EllpackCacheInfo { // This is a memory-based cache. It can be a mixed of the device memory and the host memory. struct EllpackMemCache { std::vector> pages; + std::vector on_device; std::vector offsets; // Size of each batch before concatenation. std::vector sizes_orig; @@ -65,6 +66,9 @@ struct EllpackMemCache { [[nodiscard]] std::size_t SizeBytes() const; [[nodiscard]] bool Empty() const { return this->SizeBytes() == 0; } + // No page concatenation is performed. If there's page concatenation, then the number of + // pages in the cache must be smaller than the input number of pages. 
+ [[nodiscard]] bool NoConcat() const { return this->NumBatchesOrig() == this->buffer_rows.size(); } [[nodiscard]] bst_idx_t NumBatchesOrig() const { return cache_mapping.size(); } [[nodiscard]] EllpackPageImpl const* At(std::int32_t k) const; @@ -187,6 +191,7 @@ class EllpackCacheStreamPolicy : public F { [[nodiscard]] std::unique_ptr CreateReader(StringView name, bst_idx_t offset, bst_idx_t length) const; + std::shared_ptr Share() const { return p_cache_; } }; template typename F> diff --git a/tests/cpp/data/test_ellpack_page_raw_format.cu b/tests/cpp/data/test_ellpack_page_raw_format.cu index 32f0bed1e016..2351089f6f4d 100644 --- a/tests/cpp/data/test_ellpack_page_raw_format.cu +++ b/tests/cpp/data/test_ellpack_page_raw_format.cu @@ -157,4 +157,65 @@ TEST_P(TestEllpackPageRawFormat, HostIO) { } INSTANTIATE_TEST_SUITE_P(EllpackPageRawFormat, TestEllpackPageRawFormat, ::testing::Bool()); + +TEST(EllpackPageRawFormat, DevicePageConcat) { + auto ctx = MakeCUDACtx(0); + auto param = BatchParam{256, tree::TrainParam::DftSparseThreshold()}; + bst_idx_t n_features = 16, n_samples = 128; + + auto test = [&](std::int32_t max_num_device_pages, std::int64_t min_cache_page_bytes) { + EllpackCacheInfo cinfo{param, true, max_num_device_pages, + std::numeric_limits::quiet_NaN()}; + ExternalDataInfo ext_info; + + ext_info.n_batches = 8; + ext_info.row_stride = n_features; + for (bst_idx_t i = 0; i < ext_info.n_batches; ++i) { + ext_info.base_rowids.push_back(n_samples); + } + std::partial_sum(ext_info.base_rowids.cbegin(), ext_info.base_rowids.cend(), + ext_info.base_rowids.begin()); + ext_info.accumulated_rows = n_samples * ext_info.n_batches; + ext_info.nnz = ext_info.accumulated_rows * n_features; + + auto p_fmat = RandomDataGenerator{n_samples, n_features, 0}.Seed(0).GenerateDMatrix(); + EllpackCacheStreamPolicy policy; + + for (auto const &page : p_fmat->GetBatches(&ctx, param)) { + auto cuts = page.Impl()->CutsShared(); + CalcCacheMapping(&ctx, true, cuts, min_cache_page_bytes, ext_info, &cinfo); + [&] { + ASSERT_EQ(cinfo.buffer_rows.size(), 4ul); + }(); + policy.SetCuts(page.Impl()->CutsShared(), ctx.Device(), std::move(cinfo)); + } + + auto format = policy.CreatePageFormat(param); + + // write multipe pages + for (bst_idx_t i = 0; i < ext_info.n_batches; ++i) { + for (auto const &page : p_fmat->GetBatches(&ctx, param)) { + auto writer = policy.CreateWriter({}, i); + [[maybe_unused]] auto n_bytes = format->Write(page, writer.get()); + } + } + // check correct concatenation. 
+ auto mem_cache = policy.Share(); + return mem_cache; + }; + + { + auto mem_cache = test(1, n_features * n_samples); + ASSERT_EQ(mem_cache->on_device.size(), 4); + ASSERT_TRUE(mem_cache->on_device[0]); + ASSERT_EQ(mem_cache->NumDevicePages(), 1); + } + { + auto mem_cache = test(2, n_features * n_samples); + ASSERT_EQ(mem_cache->on_device.size(), 4); + ASSERT_TRUE(mem_cache->on_device[0]); + ASSERT_TRUE(mem_cache->on_device[1]); + ASSERT_EQ(mem_cache->NumDevicePages(), 2); + } +} } // namespace xgboost::data From 7bd21832a82dd659357ce7179838c2d8dd2322bb Mon Sep 17 00:00:00 2001 From: Philip Hyunsu Cho Date: Wed, 19 Mar 2025 07:36:04 -0700 Subject: [PATCH 009/224] [Doc] CUDA 12.0+ is now required (#11344) --- doc/changes/v3.0.0.rst | 1 + doc/gpu/index.rst | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/changes/v3.0.0.rst b/doc/changes/v3.0.0.rst index e236ba3132f7..bc1722ad4f2a 100644 --- a/doc/changes/v3.0.0.rst +++ b/doc/changes/v3.0.0.rst @@ -196,6 +196,7 @@ This section lists breaking changes that affect all packages. - Support for saving the model in the ``deprecated`` has been removed. Users can still load old models in 3.0. (:pr:`10490`) - Support for the legacy (blocking) CUDA stream is removed (:pr:`10607`) +- XGBoost now requires CUDA 12.0 or later. ********* Bug Fixes diff --git a/doc/gpu/index.rst b/doc/gpu/index.rst index 9603a628cb81..515939723e49 100644 --- a/doc/gpu/index.rst +++ b/doc/gpu/index.rst @@ -4,7 +4,7 @@ XGBoost GPU Support This page contains information about GPU algorithms supported in XGBoost. -.. note:: CUDA 11.0, Compute Capability 5.0 required (See `this list `_ to look up compute capability of your GPU card.) +.. note:: CUDA 12.0, Compute Capability 5.0 required (See `this list `_ to look up compute capability of your GPU card.) ********************************************* CUDA Accelerated Tree Construction Algorithms From 62bae0fcb835c44ac2e9e8b56ad62886635ffbf2 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 20 Mar 2025 01:32:10 +0800 Subject: [PATCH 010/224] Update loky to 3.5.1. 
(#11341) --- ops/conda_env/aarch64_test.yml | 2 +- ops/conda_env/linux_cpu_test.yml | 2 +- ops/conda_env/macos_cpu_test.yml | 11 ++--------- ops/conda_env/win64_test.yml | 2 +- tests/python/test_collective.py | 6 +++++- 5 files changed, 10 insertions(+), 13 deletions(-) diff --git a/ops/conda_env/aarch64_test.yml b/ops/conda_env/aarch64_test.yml index 14305ebbf090..d7dd13639ff3 100644 --- a/ops/conda_env/aarch64_test.yml +++ b/ops/conda_env/aarch64_test.yml @@ -26,7 +26,7 @@ dependencies: - awscli - numba - llvmlite -- loky +- loky>=3.5.1 - pyarrow - pyspark>=3.4.0 - cloudpickle diff --git a/ops/conda_env/linux_cpu_test.yml b/ops/conda_env/linux_cpu_test.yml index e4c0b507c8e2..55bac17f2dbb 100644 --- a/ops/conda_env/linux_cpu_test.yml +++ b/ops/conda_env/linux_cpu_test.yml @@ -34,7 +34,7 @@ dependencies: - boto3 - awscli - py-ubjson -- loky +- loky>=3.5.1 - pyarrow - protobuf - cloudpickle diff --git a/ops/conda_env/macos_cpu_test.yml b/ops/conda_env/macos_cpu_test.yml index 29ff99e3504f..390abf141803 100644 --- a/ops/conda_env/macos_cpu_test.yml +++ b/ops/conda_env/macos_cpu_test.yml @@ -6,8 +6,6 @@ dependencies: - pip - wheel - pyyaml -- cpplint -- pylint - numpy - scipy - llvm-openmp @@ -20,22 +18,17 @@ dependencies: - python-graphviz - hypothesis - astroid -- sphinx - sh -- recommonmark -- mock -- breathe - pytest - pytest-cov +- pytest-timeout - python-kubernetes - urllib3 - jsonschema - boto3 - awscli -- loky +- loky>=3.5.1 - pyarrow -- pyspark>=3.4.0 - cloudpickle - pip: - setuptools - - sphinx_rtd_theme diff --git a/ops/conda_env/win64_test.yml b/ops/conda_env/win64_test.yml index 32b9339e6fc0..6e87e1560c21 100644 --- a/ops/conda_env/win64_test.yml +++ b/ops/conda_env/win64_test.yml @@ -16,5 +16,5 @@ dependencies: - python-graphviz - pip - py-ubjson -- loky +- loky>=3.5.1 - pyarrow diff --git a/tests/python/test_collective.py b/tests/python/test_collective.py index 473b38b5b742..1204c0faf8c9 100644 --- a/tests/python/test_collective.py +++ b/tests/python/test_collective.py @@ -3,7 +3,6 @@ import numpy as np import pytest -from loky import get_reusable_executor import xgboost as xgb from xgboost import RabitTracker, build_info, federated @@ -25,10 +24,13 @@ def run_rabit_worker(rabit_env: dict, world_size: int) -> int: @pytest.mark.skipif(**tm.no_loky()) def test_rabit_communicator() -> None: + from loky import get_reusable_executor + world_size = 2 tracker = RabitTracker(host_ip="127.0.0.1", n_workers=world_size) tracker.start() workers = [] + with get_reusable_executor(max_workers=world_size) as pool: for _ in range(world_size): worker = pool.submit( @@ -60,6 +62,8 @@ def run_federated_worker(port: int, world_size: int, rank: int) -> int: @pytest.mark.skipif(**tm.skip_win()) @pytest.mark.skipif(**tm.no_loky()) def test_federated_communicator() -> None: + from loky import get_reusable_executor + if not build_info()["USE_FEDERATED"]: pytest.skip("XGBoost not built with federated learning enabled") From dec7f5896e50191f5b5fb2314c09c2f76304a47b Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 20 Mar 2025 01:58:05 +0800 Subject: [PATCH 011/224] [EM] Optimize single batch. 
(#11339) --- src/data/ellpack_page_raw_format.cu | 4 ++-- src/data/ellpack_page_raw_format.h | 4 ++-- src/data/ellpack_page_source.cu | 6 ++++++ src/data/ellpack_page_source.h | 4 ++++ src/data/extmem_quantile_dmatrix.cu | 8 ++++---- tests/cpp/data/test_ellpack_page_raw_format.cu | 4 ++++ 6 files changed, 22 insertions(+), 8 deletions(-) diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu index 2907174a0920..955cea2d5c88 100644 --- a/src/data/ellpack_page_raw_format.cu +++ b/src/data/ellpack_page_raw_format.cu @@ -75,7 +75,7 @@ template return true; } -[[nodiscard]] std::size_t EllpackPageRawFormat::Write(const EllpackPage& page, +[[nodiscard]] std::size_t EllpackPageRawFormat::Write(EllpackPage const& page, common::AlignedFileWriteStream* fo) { xgboost_NVTX_FN_RANGE(); @@ -109,7 +109,7 @@ template return true; } -[[nodiscard]] std::size_t EllpackPageRawFormat::Write(const EllpackPage& page, +[[nodiscard]] std::size_t EllpackPageRawFormat::Write(EllpackPage const& page, EllpackHostCacheStream* fo) const { xgboost_NVTX_FN_RANGE(); diff --git a/src/data/ellpack_page_raw_format.h b/src/data/ellpack_page_raw_format.h index 9be2c50cff46..eda0e1d20978 100644 --- a/src/data/ellpack_page_raw_format.h +++ b/src/data/ellpack_page_raw_format.h @@ -38,11 +38,11 @@ class EllpackPageRawFormat : public SparsePageFormat { param_{std::move(param)}, has_hmm_ats_{has_hmm_ats} {} [[nodiscard]] bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override; - [[nodiscard]] std::size_t Write(const EllpackPage& page, + [[nodiscard]] std::size_t Write(EllpackPage const& page, common::AlignedFileWriteStream* fo) override; [[nodiscard]] bool Read(EllpackPage* page, EllpackHostCacheStream* fi) const; - [[nodiscard]] std::size_t Write(const EllpackPage& page, EllpackHostCacheStream* fo) const; + [[nodiscard]] std::size_t Write(EllpackPage const& page, EllpackHostCacheStream* fo) const; }; #if !defined(XGBOOST_USE_CUDA) diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index 8dbf2d3ec696..cd99de0d38b0 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -338,6 +338,12 @@ void CalcCacheMapping(Context const* ctx, bool is_dense, cinfo->cache_mapping = std::move(cache_mapping); cinfo->buffer_bytes = std::move(cache_bytes); cinfo->buffer_rows = std::move(cache_rows); + + // Directly store in device if there's only one batch. + if (cinfo->NumBatchesCc() == 1) { + cinfo->prefer_device = true; + LOG(INFO) << "Prefer device cache as there's only 1 page."; + } } /** diff --git a/src/data/ellpack_page_source.h b/src/data/ellpack_page_source.h index a668c39bdef4..d8d6e139c83a 100644 --- a/src/data/ellpack_page_source.h +++ b/src/data/ellpack_page_source.h @@ -38,6 +38,10 @@ struct EllpackCacheInfo { prefer_device{prefer_device}, max_num_device_pages{max_num_device_pages}, missing{missing} {} + + // Only effective for host-based cache. + // The number of batches for the concatenated cache. 
+ [[nodiscard]] std::size_t NumBatchesCc() const { return this->buffer_rows.size(); } }; // We need to decouple the storage and the view of the storage so that we can implement diff --git a/src/data/extmem_quantile_dmatrix.cu b/src/data/extmem_quantile_dmatrix.cu index 533d68b2b915..a633ac984e89 100644 --- a/src/data/extmem_quantile_dmatrix.cu +++ b/src/data/extmem_quantile_dmatrix.cu @@ -58,14 +58,14 @@ void ExtMemQuantileDMatrix::InitFromCUDA( /** * Calculate cache info */ - // Prefer device storage for validation dataset since we can't hide it's data load - // overhead with inference. But the training procedures can confortably overlap with the - // data transfer. + // Prefer device storage for validation dataset since we can't hide the data loading + // overhead with inference. On the other hand, training procedures can confortably + // overlap with the data transfer. auto cinfo = EllpackCacheInfo{p, (ref != nullptr), config.max_num_device_pages, config.missing}; CalcCacheMapping(ctx, this->info_.IsDense(), cuts, DftMinCachePageBytes(config.min_cache_page_bytes), ext_info, &cinfo); CHECK_EQ(cinfo.cache_mapping.size(), ext_info.n_batches); - auto n_batches = cinfo.buffer_rows.size(); // The number of batches after page concatenation. + auto n_batches = cinfo.NumBatchesCc(); LOG(INFO) << "Number of batches after concatenation:" << n_batches; /** diff --git a/tests/cpp/data/test_ellpack_page_raw_format.cu b/tests/cpp/data/test_ellpack_page_raw_format.cu index 2351089f6f4d..216736e05f55 100644 --- a/tests/cpp/data/test_ellpack_page_raw_format.cu +++ b/tests/cpp/data/test_ellpack_page_raw_format.cu @@ -24,6 +24,10 @@ namespace { CalcCacheMapping(ctx, Xy->IsDense(), cuts, 0, ext_info, &cinfo); CHECK_EQ(ext_info.n_batches, cinfo.cache_mapping.size()); + if (cinfo.NumBatchesCc() == 1) { + EXPECT_TRUE(cinfo.prefer_device); + cinfo.prefer_device = false; // We test the host cache. + } return cinfo; } From 66d83eed03cd7bd7610dabd82ceb1a25259ca715 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Thu, 20 Mar 2025 17:41:54 +0800 Subject: [PATCH 012/224] [CI] Remove nccl RAS workaround. (#11349) --- ops/docker_run.py | 1 - ops/pipeline/test-python-wheel-impl.sh | 1 - 2 files changed, 2 deletions(-) diff --git a/ops/docker_run.py b/ops/docker_run.py index ba6c8e8c98c0..949f7fb7807d 100644 --- a/ops/docker_run.py +++ b/ops/docker_run.py @@ -70,7 +70,6 @@ def docker_run( docker_run_cli_args.extend( itertools.chain.from_iterable([["-e", f"{k}={v}"] for k, v in user_ids.items()]) ) - docker_run_cli_args.extend(["-e", "NCCL_RAS_ENABLE=0"]) docker_run_cli_args.extend(extra_args) docker_run_cli_args.append(image_uri) docker_run_cli_args.extend(command_args) diff --git a/ops/pipeline/test-python-wheel-impl.sh b/ops/pipeline/test-python-wheel-impl.sh index 4620e6ebf7fc..5c24e31210d2 100755 --- a/ops/pipeline/test-python-wheel-impl.sh +++ b/ops/pipeline/test-python-wheel-impl.sh @@ -45,7 +45,6 @@ case "$suite" in mgpu) echo "-- Run Python tests, using multiple GPUs" python -c 'from cupy.cuda import jitify; jitify._init_module()' - export NCCL_RAS_ENABLE=0 pytest -v -s -rxXs --durations=0 -m 'mgpu' tests/python-gpu pytest -v -s -rxXs --durations=0 tests/test_distributed/test_gpu_with_dask pytest -v -s -rxXs --durations=0 tests/test_distributed/test_gpu_with_spark From 8536af5f5df1a82a4e9b924c68bc386dda591a5d Mon Sep 17 00:00:00 2001 From: jakirkham Date: Thu, 20 Mar 2025 08:27:57 -0700 Subject: [PATCH 013/224] Use RMM's pached CCCL (#11351) Make sure to search for RMM if it will be used. 
This should pick up the patched CCCL from RMM. If RMM is not being used and this is a CUDA build, search for CCCL explicitly. --- CMakeLists.txt | 52 +++++++++++++++++++++++++++----------------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 845347ea1ad6..ee18a2afdf96 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -229,30 +229,6 @@ if(USE_CUDA) endif() find_package(CUDAToolkit REQUIRED) - find_package(CCCL CONFIG) - if(CCCL_FOUND) - message(STATUS "Standalone CCCL found.") - else() - message(STATUS "Standalone CCCL not found. Attempting to use CCCL from CUDA Toolkit...") - find_package(CCCL CONFIG - HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) - if(NOT CCCL_FOUND) - message(STATUS "Could not locate CCCL from CUDA Toolkit. Using Thrust and CUB from CUDA Toolkit...") - find_package(libcudacxx CONFIG REQUIRED - HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) - find_package(CUB CONFIG REQUIRED - HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) - find_package(Thrust CONFIG REQUIRED - HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) - thrust_create_target(Thrust HOST CPP DEVICE CUDA) - add_library(CCCL::CCCL INTERFACE IMPORTED GLOBAL) - target_link_libraries(CCCL::CCCL INTERFACE libcudacxx::libcudacxx CUB::CUB Thrust) - endif() - endif() - # Define guard macros to prevent windows.h from conflicting with winsock2.h - if(WIN32) - target_compile_definitions(CCCL::CCCL INTERFACE NOMINMAX WIN32_LEAN_AND_MEAN _WINSOCKAPI_) - endif() endif() if(FORCE_COLORED_OUTPUT AND (CMAKE_GENERATOR STREQUAL "Ninja") AND @@ -338,6 +314,34 @@ if(PLUGIN_RMM) list(REMOVE_ITEM rmm_link_libs CUDA::cudart) list(APPEND rmm_link_libs CUDA::cudart_static) set_target_properties(rmm::rmm PROPERTIES INTERFACE_LINK_LIBRARIES "${rmm_link_libs}") + + # Pick up patched CCCL from RMM +elseif(USE_CUDA) + # If using CUDA and not RMM, search for CCCL. + find_package(CCCL CONFIG) + if(CCCL_FOUND) + message(STATUS "Standalone CCCL found.") + else() + message(STATUS "Standalone CCCL not found. Attempting to use CCCL from CUDA Toolkit...") + find_package(CCCL CONFIG + HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) + if(NOT CCCL_FOUND) + message(STATUS "Could not locate CCCL from CUDA Toolkit. Using Thrust and CUB from CUDA Toolkit...") + find_package(libcudacxx CONFIG REQUIRED + HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) + find_package(CUB CONFIG REQUIRED + HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) + find_package(Thrust CONFIG REQUIRED + HINTS ${CUDAToolkit_LIBRARY_DIR}/cmake) + thrust_create_target(Thrust HOST CPP DEVICE CUDA) + add_library(CCCL::CCCL INTERFACE IMPORTED GLOBAL) + target_link_libraries(CCCL::CCCL INTERFACE libcudacxx::libcudacxx CUB::CUB Thrust) + endif() + endif() + # Define guard macros to prevent windows.h from conflicting with winsock2.h + if(WIN32) + target_compile_definitions(CCCL::CCCL INTERFACE NOMINMAX WIN32_LEAN_AND_MEAN _WINSOCKAPI_) + endif() endif() if(PLUGIN_SYCL) From 0500992cccd1a695fe0030184fb84d0a5f3d703b Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Sat, 22 Mar 2025 14:14:31 +0800 Subject: [PATCH 014/224] Implement ordinal recoder for the GPU predictor. (#11347) - Unify the code path for various GPU prediction functions. - Implement re-coding for the GPU predictor. 
--- include/xgboost/c_api.h | 44 +- python-package/xgboost/core.py | 12 +- python-package/xgboost/testing/ordinal.py | 120 ++++- src/data/cat_container.cuh | 11 +- src/data/iterative_dmatrix.cc | 4 +- src/encoder/ordinal.cuh | 41 +- src/encoder/ordinal.h | 25 +- src/predictor/cpu_predictor.cc | 9 +- src/predictor/gpu_predictor.cu | 560 ++++++++++++++-------- tests/cpp/encoder/test_ordinal.cc | 17 +- tests/cpp/encoder/test_ordinal.h | 12 + tests/python-gpu/test_gpu_linear.py | 9 +- tests/python-gpu/test_gpu_ordinal.py | 84 ++++ tests/python/test_ordinal.py | 5 + 14 files changed, 723 insertions(+), 230 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 852f65d38f52..b268e84f4ab4 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -162,7 +162,49 @@ XGB_DLL int XGDMatrixCreateFromCSREx(const size_t *indptr, const unsigned *indic * @brief Create a DMatrix from columnar data. (table) * * A special type of input to the `DMatrix` is the columnar format, which refers to - * column-based dataframes based on the arrow formatt. + * column-based dataframes. XGBoost can accept both numeric data types like integers and + * floats, along with the categorical type, called dictionary in arrow's term. The + * addition of categorical type is introduced in 3.1.0. The dataframe is represented by a + * list array interfaces with one object for each column. + * + * A categorical type is represented by 3 buffers, the validity mask, the names of the + * categories (called index for most of the dataframe implementation), and the codes used + * to represent the categories in the rows. XGBoost consumes a categorical column by + * accepting two JSON-encoded arrow arrays in a list. The first item in the list is a JSON + * object with `{"offsets": IntegerArray, "values": StringArray }` representing the string + * names defined by the arrow columnar format. 
The second buffer is an masked integer + * array that stores the categorical codes along with the validity mask: + * + * @code{javascript} + * [ + * // categorical column, represented as an array (list) + * [ + * { + * 'offsets': + * { + * 'data': (129412626415808, True), + * 'typestr': ' Tuple[Type, Type]: return Df, Ser +def asarray(device: str, data: Any) -> np.ndarray: + """Wrapper to get an array.""" + if device == "cpu": + return np.asarray(data) + import cupy as cp + + return cp.asarray(data) + + def assert_allclose(device: str, a: Any, b: Any) -> None: """Dispatch the assert_allclose for devices.""" if device == "cpu": @@ -273,12 +282,12 @@ def run_mixed(DMatrixT: Type) -> None: # used with the next df b_codes = df.b.cat.codes - np.testing.assert_allclose(np.asarray(b_codes), np.array([1, 0, 2])) + assert_allclose(device, asarray(device, b_codes), np.array([1, 0, 2])) # pick codes of 3, 1 b_encoded = np.array([b_codes.iloc[2], b_codes.iloc[1]]) c_codes = df.c.cat.codes - np.testing.assert_allclose(np.asarray(c_codes), np.array([1, 0, 2])) + assert_allclose(device, asarray(device, c_codes), np.array([1, 0, 2])) # pick codes of "def", "abc" c_encoded = np.array([c_codes.iloc[2], c_codes.iloc[1]]) encoded = np.stack([b_encoded, c_encoded], axis=1) @@ -317,13 +326,19 @@ def run_invalid(DMatrixT: Type) -> None: with pytest.raises(ValueError, match="The data type doesn't match"): booster.predict(Xy) + df = Df( + {"b": [2, 1, 3, 4], "c": ["cdef", "abc", "def", "bbc"]}, dtype="category" + ) + with pytest.raises(ValueError, match="Found a category not in the training"): + booster.inplace_predict(df) + for dm in (DMatrix, QuantileDMatrix): run_invalid(dm) def run_cat_thread_safety(device: Literal["cpu", "cuda"]) -> None: """Basic tests for thread safety.""" - X, y = make_categorical(2048, 16, 112, onehot=False, cat_ratio=0.5) + X, y = make_categorical(2048, 16, 112, onehot=False, cat_ratio=0.5, device=device) Xy = QuantileDMatrix(X, y, enable_categorical=True) booster = train({"device": device}, Xy, num_boost_round=10) @@ -412,3 +427,102 @@ def run_cat_leaf(device: Literal["cpu", "cuda"]) -> None: _run_predt( device, DMatrix, pred_contribs=False, pred_interactions=False, pred_leaf=True ) + + +def run_specified_cat( # pylint: disable=too-many-locals + device: Literal["cpu", "cuda"], +) -> None: + """Run with manually specified category encoding.""" + import pandas as pd + + # Same between old and new, wiht 0 ("a") and 1 ("b") exchanged their position. + old_cats = ["a", "b", "c", "d"] + new_cats = ["b", "a", "c", "d"] + mapping = {0: 1, 1: 0} + + col0 = np.arange(0, 9) + col1 = pd.Categorical.from_codes( + # b, b, c, d, a, c, c, d, a + categories=old_cats, + codes=[1, 1, 2, 3, 0, 2, 2, 3, 0], + ) + df = pd.DataFrame({"f0": col0, "f1": col1}) + Df, _ = get_df_impl(device) + df = Df(df) + rng = np.random.default_rng(2025) + y = rng.uniform(size=df.shape[0]) + + for dm in (DMatrix, QuantileDMatrix): + Xy = dm(df, y, enable_categorical=True) + booster = train({"device": device}, Xy) + predt0 = booster.predict(Xy) + predt1 = booster.inplace_predict(df) + assert_allclose(device, predt0, predt1) + + col1 = pd.Categorical.from_codes( + # b, b, c, d, a, c, c, d, a + categories=new_cats, + codes=[0, 0, 2, 3, 1, 2, 2, 3, 1], + ) + df1 = Df({"f0": col0, "f1": col1}) + predt2 = booster.inplace_predict(df1) + assert_allclose(device, predt0, predt2) + + # Test large column numbers. XGBoost makes some specializations for slim datasets, + # make sure we cover all the cases. 
+ n_features = 4096 + n_samples = 1024 + df = pd.DataFrame() + col_numeric = rng.uniform(0, 1, size=(n_samples, n_features // 2)) + col_categorical = rng.integers( + low=0, high=4, size=(n_samples, n_features // 2), dtype=np.int32 + ) + + for c in range(n_features): + if c % 2 == 0: + col = col_numeric[:, c // 2] + else: + codes = col_categorical[:, c // 2] + col = pd.Categorical.from_codes( + categories=old_cats, + codes=codes, + ) + df[f"f{c}"] = col + + df = Df(df) + y = rng.normal(size=n_samples) + + Xy = DMatrix(df, y, enable_categorical=True) + booster = train({"device": device}, Xy) + + predt0 = booster.predict(Xy) + predt1 = booster.inplace_predict(df) + assert_allclose(device, predt0, predt1) + + for c in range(n_features): + if c % 2 == 0: + continue + + name = f"f{c}" + codes_ser = df[name].cat.codes + if hasattr(codes_ser, "to_pandas"): # cudf + codes_ser = codes_ser.to_pandas() + new_codes = codes_ser.replace(mapping) + df[name] = pd.Categorical.from_codes(categories=new_cats, codes=new_codes) + + df = Df(df) + Xy = DMatrix(df, y, enable_categorical=True) + predt2 = booster.predict(Xy) + assert_allclose(device, predt0, predt2) + + array = np.empty(shape=(n_samples, n_features)) + array[:, np.arange(0, n_features) % 2 == 0] = col_numeric + array[:, np.arange(0, n_features) % 2 != 0] = col_categorical + + if device == "cuda": + import cupy as cp + + array = cp.array(array) + + predt3 = booster.inplace_predict(array) + assert_allclose(device, predt0, predt3) diff --git a/src/data/cat_container.cuh b/src/data/cat_container.cuh index 8cfbf6ee16e1..9522a97c856a 100644 --- a/src/data/cat_container.cuh +++ b/src/data/cat_container.cuh @@ -60,13 +60,12 @@ struct EncThrustPolicy { template using ThrustAllocator = dh::XGBDeviceAllocator; - auto ThrustPolicy() const { -#if defined(XGBOOST_USE_RMM) - return rmm::exec_policy_nosync{}; -#else - return dh::CachingThrustPolicy(); -#endif // defined(XGBOOST_USE_RMM) + [[nodiscard]] auto ThrustPolicy() const { + dh::XGBCachingDeviceAllocator alloc; + auto exec = thrust::cuda::par_nosync(alloc).on(dh::DefaultStream()); + return exec; } + [[nodiscard]] auto Stream() const { return dh::DefaultStream(); } }; using EncPolicyT = enc::Policy; diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc index 3f59af9ffda2..2d6f7451d43d 100644 --- a/src/data/iterative_dmatrix.cc +++ b/src/data/iterative_dmatrix.cc @@ -51,7 +51,7 @@ IterativeDMatrix::IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle pro this->batch_ = p; LOG(INFO) << "Finished constructing the `IterativeDMatrix`: (" << this->Info().num_row_ << ", " - << this->Info().num_col_ << ", " << this->Info().num_nonzero_ << ")."; + << this->Info().num_col_ << ", " << this->info_.num_nonzero_ << ")."; } void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p, @@ -110,7 +110,7 @@ void IterativeDMatrix::InitFromCPU(Context const* ctx, BatchParam const& p, accumulated_rows += BatchSamples(proxy); } iter.Reset(); - CHECK_EQ(accumulated_rows, Info().num_row_); + CHECK_EQ(accumulated_rows, this->info_.num_row_); if (ext_info.n_batches == 1) { this->info_ = std::move(proxy->Info()); diff --git a/src/encoder/ordinal.cuh b/src/encoder/ordinal.cuh index 282441d4a0d3..42300e4c38ef 100644 --- a/src/encoder/ordinal.cuh +++ b/src/encoder/ordinal.cuh @@ -98,8 +98,8 @@ struct SegmentedSearchSortedNumOp { haystack_v.feature_segments[f_idx + 1] - haystack_v.feature_segments[f_idx]); auto end_it = it + f_sorted_idx.size(); auto ret_it = thrust::lower_bound(thrust::seq, it, 
end_it, SearchKey(), [&](auto l, auto r) { - T l_value = l == SearchKey() ? needle : haystack[ref_sorted_idx[l]]; - T r_value = r == SearchKey() ? needle : haystack[ref_sorted_idx[r]]; + T l_value = l == SearchKey() ? needle : haystack[f_sorted_idx[l]]; + T r_value = r == SearchKey() ? needle : haystack[f_sorted_idx[r]]; return l_value < r_value; }); if (ret_it == it + f_sorted_idx.size()) { @@ -122,7 +122,8 @@ struct DftThrustPolicy { template using ThrustAllocator = thrust::device_allocator; - auto ThrustPolicy() const { return thrust::cuda::par_nosync; } + [[nodiscard]] auto ThrustPolicy() const { return thrust::cuda::par_nosync; } + [[nodiscard]] auto Stream() const { return cudaStreamPerThread; } }; } // namespace cuda_impl @@ -144,12 +145,15 @@ using DftDevicePolicy = Policy void SortNames(ExecPolicy const& policy, DeviceColumnsView orig_enc, Span sorted_idx) { + typename ExecPolicy::template ThrustAllocator alloc; + auto exec = thrust::cuda::par_nosync(alloc).on(policy.Stream()); + auto n_total_cats = orig_enc.n_total_cats; if (static_cast(sorted_idx.size()) != orig_enc.n_total_cats) { policy.Error("`sorted_idx` should have the same size as `n_total_cats`."); } auto d_sorted_idx = dh::ToSpan(sorted_idx); - cuda_impl::SegmentedIota(policy.ThrustPolicy(), orig_enc.feature_segments, d_sorted_idx); + cuda_impl::SegmentedIota(exec, orig_enc.feature_segments, d_sorted_idx); // using Pair = cuda::std::pair; @@ -162,9 +166,9 @@ void SortNames(ExecPolicy const& policy, DeviceColumnsView orig_enc, auto idx = d_sorted_idx[i]; return cuda::std::make_pair(static_cast(seg), idx); })); - thrust::copy(policy.ThrustPolicy(), key_it, key_it + n_total_cats, keys.begin()); + thrust::copy(exec, key_it, key_it + n_total_cats, keys.begin()); - thrust::sort(policy.ThrustPolicy(), keys.begin(), keys.end(), + thrust::sort(exec, keys.begin(), keys.end(), cuda::proclaim_return_type([=] __device__(Pair const& l, Pair const& r) { if (l.first == r.first) { // same feature auto const& col = orig_enc.columns[l.first]; @@ -193,7 +197,7 @@ void SortNames(ExecPolicy const& policy, DeviceColumnsView orig_enc, thrust::make_counting_iterator(0), cuda::proclaim_return_type( [=] __device__(std::int32_t i) { return s_keys[i].second; })); - thrust::copy(policy.ThrustPolicy(), it, it + sorted_idx.size(), dh::tbegin(sorted_idx)); + thrust::copy(exec, it, it + sorted_idx.size(), dh::tbegin(sorted_idx)); } /** @@ -212,8 +216,27 @@ template void Recode(ExecPolicy const& policy, DeviceColumnsView orig_enc, Span sorted_idx, DeviceColumnsView new_enc, Span mapping) { - auto exec = policy.ThrustPolicy(); + typename ExecPolicy::template ThrustAllocator alloc; + auto exec = thrust::cuda::par_nosync(alloc).on(policy.Stream()); detail::BasicChecks(policy, orig_enc, sorted_idx, new_enc, mapping); + /** + * Check consistency. + */ + auto check_it = thrust::make_transform_iterator( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(std::size_t i) { + auto l_f = orig_enc.columns[i]; + auto r_f = new_enc.columns[i]; + auto l_is_empty = cuda::std::visit([](auto&& arg) { return arg.empty(); }, l_f); + auto r_is_empty = cuda::std::visit([](auto&& arg) { return arg.empty(); }, r_f); + return l_is_empty == r_is_empty; + }); + bool valid = thrust::reduce(exec, check_it, check_it + new_enc.Size(), true, + [=] XGBOOST_DEVICE(bool l, bool r) { return l && r; }); + if (!valid) { + policy.Error( + "Invalid new DataFrame. 
" + "The data type doesn't match the one used in the training dataset."); + } /** * search the index for the new encoding @@ -222,7 +245,7 @@ void Recode(ExecPolicy const& policy, DeviceColumnsView orig_enc, exec, thrust::make_counting_iterator(0), new_enc.n_total_cats, [=] __device__(std::int32_t i) { auto f_idx = dh::SegmentId(new_enc.feature_segments, i); - std::int32_t searched_idx{-1}; + std::int32_t searched_idx{detail::NotFound()}; auto const& col = orig_enc.columns[f_idx]; cuda::std::visit(Overloaded{[&](CatStrArrayView const& str) { auto op = cuda_impl::SegmentedSearchSortedStrOp{ diff --git a/src/encoder/ordinal.h b/src/encoder/ordinal.h index bfb334d29666..d4de6d0c8a59 100644 --- a/src/encoder/ordinal.h +++ b/src/encoder/ordinal.h @@ -107,7 +107,8 @@ using DeviceCatIndexView = cuda_impl::TupToVarT; * Accepted policies: * * - A class with a `ThrustPolicy` method that returns a thrust execution policy, along with a - * `ThrustAllocator` template type. This is only used for the GPU implementation. + * `ThrustAllocator` template type. In addition, a `Stream` method that returns a CUDA stream. + * This is only used for the GPU implementation. * * - An error handling policy that exposes a single `Error` method, which takes a single * string parameter for error message. @@ -133,6 +134,7 @@ struct ColumnsViewImpl { [[nodiscard]] std::size_t Size() const { return columns.size(); } [[nodiscard]] bool Empty() const { return this->Size() == 0; } [[nodiscard]] auto operator[](std::size_t i) const { return columns[i]; } + [[nodiscard]] auto HasCategorical() const { return n_total_cats != 0; } }; struct DftErrorHandler { @@ -418,4 +420,25 @@ inline std::ostream &operator<<(std::ostream &os, CatStrArrayView const &strings os << "]"; return os; } + +inline std::ostream &operator<<(std::ostream &os, HostColumnsView const &h_enc) { + for (std::size_t i = 0; i < h_enc.columns.size(); ++i) { + auto const &col = h_enc.columns[i]; + os << "f" << i << ": "; + std::visit(enc::Overloaded{[&](enc::CatStrArrayView const &str) { os << str; }, + [&](auto &&values) { + os << "["; + for (std::size_t j = 0, n = values.size(); j < n; ++j) { + os << values[j]; + if (j != n - 1) { + os << ", "; + } + } + os << "]"; + }}, + col); + os << std::endl; + } + return os; +} } // namespace enc diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index d986882a6795..c82ece98d83c 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -367,13 +367,14 @@ static void InitThreadTemp(int nthread, std::vector *out) { } } -auto MakeCatAccessor(Context const *ctx, enc::HostColumnsView const &cats, +auto MakeCatAccessor(Context const *ctx, enc::HostColumnsView const &new_enc, gbm::GBTreeModel const &model) { - std::vector mapping(cats.n_total_cats); + std::vector mapping(new_enc.n_total_cats); auto sorted_idx = model.Cats()->RefSortedIndex(ctx); auto orig_enc = model.Cats()->HostView(); - enc::Recode(cpu_impl::EncPolicy, orig_enc, sorted_idx, cats, common::Span{mapping}); - auto cats_mapping = enc::MappingView{cats.feature_segments, mapping}; + enc::Recode(cpu_impl::EncPolicy, orig_enc, sorted_idx, new_enc, common::Span{mapping}); + CHECK_EQ(new_enc.feature_segments.size(), orig_enc.feature_segments.size()); + auto cats_mapping = enc::MappingView{new_enc.feature_segments, mapping}; auto acc = CatAccessor{cats_mapping}; return std::tuple{acc, std::move(mapping)}; } diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index d99f00cd35a4..f00641d9f5a7 100644 
--- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -16,11 +16,13 @@ #include "../common/cuda_context.cuh" // for CUDAContext #include "../common/cuda_rt_utils.h" // for AllVisibleGPUs, SetDevice #include "../common/device_helpers.cuh" -#include "../common/error_msg.h" // for InplacePredictProxy -#include "../data/batch_utils.h" // for StaticBatch +#include "../common/error_msg.h" // for InplacePredictProxy +#include "../data/batch_utils.h" // for StaticBatch +#include "../data/cat_container.cuh" // for EncPolicy #include "../data/device_adapter.cuh" #include "../data/ellpack_page.cuh" #include "../data/proxy_dmatrix.h" +#include "../encoder/ordinal.cuh" // for CudaCategoryRecoder #include "../gbm/gbtree_model.h" #include "predict_fn.h" #include "xgboost/data.h" @@ -74,9 +76,8 @@ struct SparsePageView { SparsePageView() = default; XGBOOST_DEVICE SparsePageView(common::Span data, - common::Span row_ptr, - bst_feature_t num_features) - : d_data{data}, d_row_ptr{row_ptr}, num_features(num_features) {} + common::Span row_ptr, bst_feature_t n_features) + : d_data{data}, d_row_ptr{row_ptr}, num_features(n_features) {} [[nodiscard]] __device__ float GetElement(size_t ridx, size_t fidx) const { // Binary search auto begin_ptr = d_data.begin() + d_row_ptr[ridx]; @@ -109,14 +110,19 @@ struct SparsePageView { [[nodiscard]] XGBOOST_DEVICE size_t NumCols() const { return num_features; } }; +template struct SparsePageLoader { + private: + EncAccessor acc_; + + public: bool use_shared; SparsePageView data; float* smem; __device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features, - bst_idx_t num_rows, float) - : use_shared(use_shared), data(data) { + bst_idx_t num_rows, float, EncAccessor&& acc) + : use_shared(use_shared), data(data), acc_{std::forward(acc)} { extern __shared__ float _smem[]; smem = _smem; // Copy instances @@ -130,7 +136,7 @@ struct SparsePageLoader { bst_uint elem_end = data.d_row_ptr[global_idx + 1]; for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) { Entry elem = data.d_data[elem_idx]; - smem[threadIdx.x * data.num_features + elem.index] = elem.fvalue; + smem[threadIdx.x * data.num_features + elem.index] = this->acc_(elem); } } __syncthreads(); @@ -140,22 +146,27 @@ struct SparsePageLoader { if (use_shared) { return smem[threadIdx.x * data.num_features + fidx]; } else { - return data.GetElement(ridx, fidx); + return this->acc_(data.GetElement(ridx, fidx), fidx); } } }; +template struct EllpackLoader { EllpackDeviceAccessor matrix; - XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor m, bool, bst_feature_t, bst_idx_t, float) - : matrix{std::move(m)} {} + EncAccessor acc; + + XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor m, bool /*use_shared*/, + bst_feature_t /*n_features*/, bst_idx_t /*n_samples*/, + float /*missing*/, EncAccessor&& acc) + : matrix{std::move(m)}, acc{std::forward(acc)} {} [[nodiscard]] XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const { auto gidx = matrix.GetBinIndex(ridx, fidx); if (gidx == -1) { return std::numeric_limits::quiet_NaN(); } if (common::IsCat(matrix.feature_types, fidx)) { - return matrix.gidx_fvalue_map[gidx]; + return this->acc(matrix.gidx_fvalue_map[gidx], fidx); } // The gradient index needs to be shifted by one as min values are not included in the // cuts. @@ -168,34 +179,45 @@ struct EllpackLoader { [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return this->matrix.n_rows; } }; -template +/** + * @brief Use for in-place predict. 
+ */ +template struct DeviceAdapterLoader { - Batch batch; - bst_feature_t columns; + private: + Batch batch_; + EncAccessor acc_; + + public: + bst_feature_t n_features; float* smem; bool use_shared; data::IsValidFunctor is_valid; + using BatchT = Batch; - XGBOOST_DEV_INLINE DeviceAdapterLoader(Batch const batch, bool use_shared, - bst_feature_t num_features, bst_idx_t num_rows, - float missing) - : batch{batch}, columns{num_features}, use_shared{use_shared}, is_valid{missing} { + XGBOOST_DEV_INLINE DeviceAdapterLoader(Batch&& batch, bool use_shared, bst_feature_t n_features, + bst_idx_t n_samples, float missing, EncAccessor&& acc) + : batch_{std::move(batch)}, + acc_{std::forward(acc)}, + n_features{n_features}, + use_shared{use_shared}, + is_valid{missing} { extern __shared__ float _smem[]; - smem = _smem; - if (use_shared) { - uint32_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; - size_t shared_elements = blockDim.x * num_features; + this->smem = _smem; + if (this->use_shared) { + auto global_idx = blockDim.x * blockIdx.x + threadIdx.x; + size_t shared_elements = blockDim.x * n_features; dh::BlockFill(smem, shared_elements, std::numeric_limits::quiet_NaN()); __syncthreads(); - if (global_idx < num_rows) { - auto beg = global_idx * columns; - auto end = (global_idx + 1) * columns; + if (global_idx < n_samples) { + auto beg = global_idx * n_features; + auto end = (global_idx + 1) * n_features; for (size_t i = beg; i < end; ++i) { - auto value = batch.GetElement(i).value; - if (is_valid(value)) { - smem[threadIdx.x * num_features + (i - beg)] = value; + data::COOTuple const& e = this->batch_.GetElement(i); + if (is_valid(e)) { + smem[threadIdx.x * n_features + (i - beg)] = this->acc_(e); } } } @@ -205,11 +227,11 @@ struct DeviceAdapterLoader { [[nodiscard]] XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const { if (use_shared) { - return smem[threadIdx.x * columns + fidx]; + return smem[threadIdx.x * n_features + fidx]; } - auto value = batch.GetElement(ridx * columns + fidx).value; + auto value = this->batch_.GetElement(ridx * n_features + fidx).value; if (is_valid(value)) { - return value; + return this->acc_(value, fidx); } else { return std::numeric_limits::quiet_NaN(); } @@ -241,7 +263,7 @@ __device__ float GetLeafWeight(bst_idx_t ridx, TreeView const &tree, return tree.d_tree[nidx].LeafValue(); } -template +template __global__ void PredictLeafKernel(Data data, common::Span d_nodes, common::Span d_out_predictions, @@ -254,12 +276,12 @@ PredictLeafKernel(Data data, common::Span d_nodes, bst_tree_t tree_begin, bst_tree_t tree_end, bst_feature_t num_features, size_t num_rows, bool use_shared, - float missing) { + float missing, EncAccessor acc) { bst_idx_t ridx = blockDim.x * blockIdx.x + threadIdx.x; if (ridx >= num_rows) { return; } - Loader loader{data, use_shared, num_features, num_rows, missing}; + Loader loader{std::move(data), use_shared, num_features, num_rows, missing, std::move(acc)}; for (bst_tree_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { TreeView d_tree{ tree_begin, tree_idx, d_nodes, @@ -268,15 +290,15 @@ PredictLeafKernel(Data data, common::Span d_nodes, bst_node_t leaf = -1; if (d_tree.HasCategoricalSplit()) { - leaf = GetLeafIndex(ridx, d_tree, &loader); + leaf = GetLeafIndex(ridx, d_tree, &loader); } else { - leaf = GetLeafIndex(ridx, d_tree, &loader); + leaf = GetLeafIndex(ridx, d_tree, &loader); } d_out_predictions[ridx * (tree_end - tree_begin) + tree_idx] = leaf; } } -template +template __global__ void PredictKernel(Data 
data, common::Span d_nodes, common::Span d_out_predictions, @@ -286,10 +308,10 @@ PredictKernel(Data data, common::Span d_nodes, common::Span d_cat_tree_segments, common::Span d_cat_node_segments, common::Span d_categories, bst_tree_t tree_begin, - bst_tree_t tree_end, size_t num_features, size_t num_rows, - bool use_shared, int num_group, float missing) { + bst_tree_t tree_end, bst_feature_t num_features, size_t num_rows, + bool use_shared, int num_group, float missing, EncAccessor acc) { bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x; - Loader loader(data, use_shared, num_features, num_rows, missing); + Loader loader{std::move(data), use_shared, num_features, num_rows, missing, std::move(acc)}; if (global_idx >= num_rows) return; if (num_group == 1) { @@ -332,11 +354,13 @@ class DeviceModel { HostDeviceVector categories_node_segments; HostDeviceVector categories; - size_t tree_beg_; // NOLINT - size_t tree_end_; // NOLINT + bst_tree_t tree_beg_; // NOLINT + bst_tree_t tree_end_; // NOLINT int num_group; + CatContainer const* cat_enc{nullptr}; - void Init(const gbm::GBTreeModel& model, size_t tree_begin, size_t tree_end, DeviceOrd device) { + void Init(const gbm::GBTreeModel& model, bst_tree_t tree_begin, bst_tree_t tree_end, + DeviceOrd device) { dh::safe_cuda(cudaSetDevice(device.ordinal)); // Copy decision trees to device @@ -406,17 +430,21 @@ class DeviceModel { this->tree_beg_ = tree_begin; this->tree_end_ = tree_end; this->num_group = model.learner_model_param->OutputLength(); + + this->cat_enc = model.Cats(); + CHECK(this->cat_enc); } }; struct ShapSplitCondition { ShapSplitCondition() = default; XGBOOST_DEVICE - ShapSplitCondition(float feature_lower_bound, float feature_upper_bound, - bool is_missing_branch, common::CatBitField cats) + ShapSplitCondition(float feature_lower_bound, float feature_upper_bound, bool is_missing_branch, + common::CatBitField cats) : feature_lower_bound(feature_lower_bound), feature_upper_bound(feature_upper_bound), - is_missing_branch(is_missing_branch), categories{std::move(cats)} { + is_missing_branch(is_missing_branch), + categories{std::move(cats)} { assert(feature_lower_bound <= feature_upper_bound); } @@ -624,7 +652,7 @@ __global__ void MaskBitVectorKernel( bst_tree_t tree_begin, bst_tree_t tree_end, bst_feature_t num_features, std::size_t num_rows, std::size_t num_nodes, bool use_shared, float missing) { // This needs to be always instantiated since the data is loaded cooperatively by all threads. 
- SparsePageLoader loader{data, use_shared, num_features, num_rows, missing}; + SparsePageLoader loader{data, use_shared, num_features, num_rows, missing, NoOpAccessor{}}; auto const row_idx = blockIdx.x * blockDim.x + threadIdx.x; if (row_idx >= num_rows) { return; @@ -841,99 +869,209 @@ class ColumnSplitHelper { Context const* ctx_; }; -} // anonymous namespace -class GPUPredictor : public xgboost::Predictor { +auto MakeCatAccessor(Context const* ctx, enc::DeviceColumnsView const& new_enc, + DeviceModel const& model) { + dh::DeviceUVector mapping(new_enc.n_total_cats); + auto d_sorted_idx = model.cat_enc->RefSortedIndex(ctx); + auto orig_enc = model.cat_enc->DeviceView(ctx); + enc::Recode(cuda_impl::EncPolicy, orig_enc, d_sorted_idx, new_enc, dh::ToSpan(mapping)); + CHECK_EQ(new_enc.feature_segments.size(), orig_enc.feature_segments.size()); + auto cats_mapping = enc::MappingView{new_enc.feature_segments, dh::ToSpan(mapping)}; + auto acc = CatAccessor{cats_mapping}; + return std::tuple{acc, std::move(mapping)}; +} + +template +struct ShapSparsePageView { + SparsePageView data; + EncAccessor acc; + + template + [[nodiscard]] __device__ float GetElement(bst_idx_t ridx, Fidx fidx) const { + auto fvalue = data.GetElement(ridx, fidx); + return acc(fvalue, fidx); + } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumRows() const { return data.NumRows(); } + [[nodiscard]] XGBOOST_DEVICE bst_idx_t NumCols() const { return data.NumCols(); } +}; + +template +void LaunchPredictKernel(Context const* ctx, bool is_dense, enc::DeviceColumnsView const& new_enc, + DeviceModel const& model, Kernel&& launch) { + if (is_dense) { + auto is_dense = std::true_type{}; + if (model.cat_enc->HasCategorical() && new_enc.HasCategorical()) { + auto [acc, mapping] = MakeCatAccessor(ctx, new_enc, model); + launch(is_dense, std::move(acc)); + } else { + launch(is_dense, NoOpAccessor{}); + } + } else { + auto is_dense = std::false_type{}; + if (model.cat_enc->HasCategorical() && new_enc.HasCategorical()) { + auto [acc, mapping] = MakeCatAccessor(ctx, new_enc, model); + launch(is_dense, std::move(acc)); + } else { + launch(is_dense, NoOpAccessor{}); + } + } +} + +// provide configuration for launching the predict kernel. +template +class LaunchConfig { private: - void PredictInternal(const SparsePage& batch, DeviceModel const& model, size_t num_features, - HostDeviceVector* predictions, size_t batch_offset, - bool is_dense) const { - batch.offset.SetDevice(ctx_->Device()); - batch.data.SetDevice(ctx_->Device()); - const uint32_t BLOCK_THREADS = 128; - bst_idx_t num_rows = batch.Size(); - auto GRID_SIZE = static_cast(common::DivRoundUp(num_rows, BLOCK_THREADS)); - auto max_shared_memory_bytes = ConfigureDevice(ctx_->Device()); - size_t shared_memory_bytes = - SharedMemoryBytes(num_features, max_shared_memory_bytes); - bool use_shared = shared_memory_bytes != 0; - - SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), - num_features); - auto const kernel = [&](auto predict_fn) { - dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes, ctx_->CUDACtx()->Stream()}( - predict_fn, data, model.nodes.ConstDeviceSpan(), + static auto constexpr NotSet() { return std::numeric_limits::max(); } + + Context const* ctx_; + std::size_t const shared_memory_bytes_; + bst_idx_t n_samples_{NotSet()}; + + template + void LaunchImpl(K&& kernel, Args&&... 
args) const&& { + CHECK_NE(this->n_samples_, NotSet()); + auto grid = static_cast(common::DivRoundUp(this->n_samples_, kBlockThreads)); + dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes_, ctx_->CUDACtx()->Stream()}( + kernel, std::forward(args)...); + } + + [[nodiscard]] LaunchConfig Grid(bst_idx_t n_samples) const { + LaunchConfig cfg = *this; + cfg.n_samples_ = n_samples; + return cfg; + } + [[nodiscard]] bool UseShared() const { return shared_memory_bytes_ != 0; } + + [[nodiscard]] static std::size_t ConfigureDevice(DeviceOrd const& device) { + thread_local std::unordered_map max_shared; + if (device.IsCUDA()) { + auto it = max_shared.find(device.ordinal); + if (it == max_shared.cend()) { + max_shared[device.ordinal] = dh::MaxSharedMemory(device.ordinal); + it = max_shared.find(device.ordinal); + } + return it->second; + } + return 0; + } + + public: + LaunchConfig(Context const* ctx, bst_feature_t n_features) + : ctx_{ctx}, + shared_memory_bytes_{kUseShared ? SharedMemoryBytes( + n_features, ConfigureDevice(ctx->Device())) + : 0} {} + + template