
Commit b4453f1

Authored by thomasjpfan, NicolasHug and ogrisel
ENH Add Categorical support for HistGradientBoosting (scikit-learn#18394)
Co-authored-by: Nicolas Hug <[email protected]>
Co-authored-by: Olivier Grisel <[email protected]>
1 parent 04c080a commit b4453f1

24 files changed, with 2206 additions and 182 deletions
Lines changed: 90 additions & 0 deletions (new file: benchmark on the adult dataset)
@@ -0,0 +1,90 @@
"""Benchmark HistGradientBoostingClassifier with native categorical support
on the adult dataset (OpenML data_id=179), optionally against an equivalent
LightGBM model."""
import argparse
from time import time

from sklearn.model_selection import train_test_split
from sklearn.datasets import fetch_openml
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.ensemble._hist_gradient_boosting.utils import (
    get_equivalent_estimator)


parser = argparse.ArgumentParser()
parser.add_argument('--n-leaf-nodes', type=int, default=31)
parser.add_argument('--n-trees', type=int, default=100)
parser.add_argument('--lightgbm', action="store_true", default=False)
parser.add_argument('--learning-rate', type=float, default=.1)
parser.add_argument('--max-bins', type=int, default=255)
parser.add_argument('--no-predict', action="store_true", default=False)
parser.add_argument('--verbose', action="store_true", default=False)
args = parser.parse_args()

n_leaf_nodes = args.n_leaf_nodes
n_trees = args.n_trees
lr = args.learning_rate
max_bins = args.max_bins
verbose = args.verbose


def fit(est, data_train, target_train, libname, **fit_params):
    print(f"Fitting a {libname} model...")
    tic = time()
    est.fit(data_train, target_train, **fit_params)
    toc = time()
    print(f"fitted in {toc - tic:.3f}s")


def predict(est, data_test, target_test):
    if args.no_predict:
        return
    tic = time()
    predicted_test = est.predict(data_test)
    predicted_proba_test = est.predict_proba(data_test)
    toc = time()
    roc_auc = roc_auc_score(target_test, predicted_proba_test[:, 1])
    acc = accuracy_score(target_test, predicted_test)
    print(f"predicted in {toc - tic:.3f}s, "
          f"ROC AUC: {roc_auc:.4f}, ACC: {acc:.4f}")


data = fetch_openml(data_id=179, as_frame=False)  # adult dataset
X, y = data.data, data.target

n_features = X.shape[1]
n_categorical_features = len(data.categories)
n_numerical_features = n_features - n_categorical_features
print(f"Number of features: {n_features}")
print(f"Number of categorical features: {n_categorical_features}")
print(f"Number of numerical features: {n_numerical_features}")

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.2,
                                                    random_state=0)

# Note: no need to use an OrdinalEncoder because categorical features are
# already clean
is_categorical = [name in data.categories for name in data.feature_names]
est = HistGradientBoostingClassifier(
    loss='binary_crossentropy',
    learning_rate=lr,
    max_iter=n_trees,
    max_bins=max_bins,
    max_leaf_nodes=n_leaf_nodes,
    categorical_features=is_categorical,
    early_stopping=False,
    random_state=0,
    verbose=verbose
)

fit(est, X_train, y_train, 'sklearn')
predict(est, X_test, y_test)

if args.lightgbm:
    est = get_equivalent_estimator(est, lib='lightgbm')
    est.set_params(max_cat_to_onehot=1)  # don't use OHE
    categorical_features = [f_idx
                            for (f_idx, is_cat) in enumerate(is_categorical)
                            if is_cat]
    fit(est, X_train, y_train, 'lightgbm',
        categorical_feature=categorical_features)
    predict(est, X_test, y_test)
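
This script might be run as, e.g., `python bench_adult.py --lightgbm --n-trees 100` (the file name here is only a hypothetical save name). It first fits and times the scikit-learn estimator on an 80/20 train/test split and, when `--lightgbm` is passed, repeats the measurement with the equivalent LightGBM model returned by `get_equivalent_estimator`.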
Lines changed: 84 additions & 0 deletions (new file: benchmark on synthetic, all-categorical data)
@@ -0,0 +1,84 @@
"""Benchmark HistGradientBoostingClassifier on synthetic data where every
feature is treated as categorical (continuous features are binned into
ordinal codes with KBinsDiscretizer), optionally against an equivalent
LightGBM model."""
import argparse
from time import time

from sklearn.preprocessing import KBinsDiscretizer
from sklearn.datasets import make_classification
from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.ensemble._hist_gradient_boosting.utils import (
    get_equivalent_estimator)


parser = argparse.ArgumentParser()
parser.add_argument('--n-leaf-nodes', type=int, default=31)
parser.add_argument('--n-trees', type=int, default=100)
parser.add_argument('--n-features', type=int, default=20)
parser.add_argument('--n-cats', type=int, default=20)
parser.add_argument('--n-samples', type=int, default=10_000)
parser.add_argument('--lightgbm', action="store_true", default=False)
parser.add_argument('--learning-rate', type=float, default=.1)
parser.add_argument('--max-bins', type=int, default=255)
parser.add_argument('--no-predict', action="store_true", default=False)
parser.add_argument('--verbose', action="store_true", default=False)
args = parser.parse_args()

n_leaf_nodes = args.n_leaf_nodes
n_features = args.n_features
n_categories = args.n_cats
n_samples = args.n_samples
n_trees = args.n_trees
lr = args.learning_rate
max_bins = args.max_bins
verbose = args.verbose


def fit(est, data_train, target_train, libname, **fit_params):
    print(f"Fitting a {libname} model...")
    tic = time()
    est.fit(data_train, target_train, **fit_params)
    toc = time()
    print(f"fitted in {toc - tic:.3f}s")


def predict(est, data_test):
    # We don't report accuracy or ROC because the dataset doesn't really make
    # sense: we treat ordered features as un-ordered categories.
    if args.no_predict:
        return
    tic = time()
    est.predict(data_test)
    toc = time()
    print(f"predicted in {toc - tic:.3f}s")


X, y = make_classification(n_samples=n_samples, n_features=n_features,
                           random_state=0)

X = KBinsDiscretizer(n_bins=n_categories, encode='ordinal').fit_transform(X)

print(f"Number of features: {n_features}")
print(f"Number of samples: {n_samples}")

is_categorical = [True] * n_features
est = HistGradientBoostingClassifier(
    loss='binary_crossentropy',
    learning_rate=lr,
    max_iter=n_trees,
    max_bins=max_bins,
    max_leaf_nodes=n_leaf_nodes,
    categorical_features=is_categorical,
    early_stopping=False,
    random_state=0,
    verbose=verbose
)

fit(est, X, y, 'sklearn')
predict(est, X)

if args.lightgbm:
    est = get_equivalent_estimator(est, lib='lightgbm')
    est.set_params(max_cat_to_onehot=1)  # don't use OHE
    categorical_features = list(range(n_features))
    fit(est, X, y, 'lightgbm',
        categorical_feature=categorical_features)
    predict(est, X)
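
As above, a run such as `python bench_categorical_only.py --n-samples 100000 --lightgbm` (hypothetical file name) only reports fit and predict times: each feature is discretized into `--n-cats` ordinal codes by `KBinsDiscretizer` and then deliberately treated as an unordered category, so accuracy metrics would not be meaningful.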

doc/modules/ensemble.rst

Lines changed: 68 additions & 0 deletions
@@ -1051,6 +1051,68 @@ multiplying the gradients (and the hessians) by the sample weights. Note that
the binning stage (specifically the quantiles computation) does not take the
weights into account.

.. _categorical_support_gbdt:

Categorical Features Support
----------------------------

:class:`HistGradientBoostingClassifier` and
:class:`HistGradientBoostingRegressor` have native support for categorical
features: they can consider splits on non-ordered, categorical data.

For datasets with categorical features, using the native categorical support
is often better than relying on one-hot encoding
(:class:`~sklearn.preprocessing.OneHotEncoder`), because one-hot encoding
requires more tree depth to achieve equivalent splits. It is also usually
better to rely on the native categorical support rather than treating
categorical features as continuous (ordinal), which is what happens with
ordinal-encoded categorical data, since categories are nominal quantities
where order does not matter.

To enable categorical support, a boolean mask can be passed to the
`categorical_features` parameter, indicating which features are categorical.
In the following, the first feature will be treated as categorical and the
second feature as numerical::

  >>> gbdt = HistGradientBoostingClassifier(categorical_features=[True, False])

Equivalently, one can pass a list of integers indicating the indices of the
categorical features::

  >>> gbdt = HistGradientBoostingClassifier(categorical_features=[0])

The cardinality of each categorical feature should be less than the `max_bins`
parameter, and each categorical feature is expected to be encoded in
`[0, max_bins - 1]`. To that end, it might be useful to pre-process the data
with an :class:`~sklearn.preprocessing.OrdinalEncoder` as done in
:ref:`sphx_glr_auto_examples_ensemble_plot_gradient_boosting_categorical.py`.
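
As an illustration, the following is a minimal sketch (not part of the rendered
documentation; the toy data is made up) of mapping string categories to the
expected integer codes with an :class:`~sklearn.preprocessing.OrdinalEncoder`
placed in front of the estimator::

  >>> import numpy as np
  >>> from sklearn.pipeline import make_pipeline
  >>> from sklearn.preprocessing import OrdinalEncoder
  >>> from sklearn.experimental import enable_hist_gradient_boosting  # noqa
  >>> from sklearn.ensemble import HistGradientBoostingClassifier
  >>> X = np.array([['cat'], ['dog'], ['cat'], ['fish']], dtype=object)
  >>> y = [0, 1, 0, 1]
  >>> model = make_pipeline(
  ...     OrdinalEncoder(),  # encodes each column into [0, n_categories - 1]
  ...     HistGradientBoostingClassifier(categorical_features=[0]))
  >>> model = model.fit(X, y)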

If there are missing values during training, the missing values will be
treated as a proper category. If there are no missing values during training,
then at prediction time, missing values are mapped to the child node that has
the most samples (just like for continuous features). When predicting,
categories that were not seen during fit time will be treated as missing
values.

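For instance (a hedged toy sketch, not taken from the documentation), `np.nan`
acts as an extra category during fit, and an integer code never seen during
fit is handled like a missing value at prediction time::

  >>> import numpy as np
  >>> X = np.array([[0.], [1.], [1.], [np.nan]])  # one categorical feature
  >>> y = [0, 1, 1, 0]
  >>> gbdt = HistGradientBoostingClassifier(categorical_features=[True])
  >>> gbdt = gbdt.fit(X, y)
  >>> # the code 2. was never seen during fit: it is treated as missing here
  >>> predictions = gbdt.predict(np.array([[np.nan], [2.]]))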

**Split finding with categorical features**: The canonical way of handling
categorical splits in a tree is to consider
all of the :math:`2^{K - 1} - 1` partitions, where :math:`K` is the number of
categories. This can quickly become prohibitive when :math:`K` is large.
Fortunately, since gradient boosting trees are always regression trees (even
for classification problems), there exists a faster strategy that can yield
equivalent splits. First, the categories of a feature are sorted according to
the variance of the target for each category `k`. Once the categories are
sorted, one can consider *continuous partitions*, i.e. treat the categories
as if they were ordered continuous values (see Fisher [Fisher1958]_ for a
formal proof). As a result, only :math:`K - 1` splits need to be considered
instead of :math:`2^{K - 1} - 1`. The initial sorting is a
:math:`\mathcal{O}(K \log(K))` operation, leading to a total complexity of
:math:`\mathcal{O}(K \log(K) + K)`, instead of :math:`\mathcal{O}(2^K)`.

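To make the sorting trick concrete, here is a small self-contained sketch
(illustration only: categories are sorted by their mean target and partitions
are scored with a plain squared-error criterion, whereas the estimator itself
works on histogrammed gradient statistics)::

  import numpy as np

  rng = np.random.RandomState(0)
  cats = rng.randint(0, 5, size=200)                # K = 5 categories
  y = (cats % 2) + rng.normal(scale=0.1, size=200)  # toy target

  # Sort the categories by a per-category statistic of the target
  # (the mean is used here purely for illustration).
  order = sorted(np.unique(cats), key=lambda k: y[cats == k].mean())

  # Scan only the K - 1 "continuous" partitions of the sorted categories,
  # instead of all 2**(K - 1) - 1 possible subsets.
  for i in range(1, len(order)):
      left = np.isin(cats, order[:i])
      sse = (((y[left] - y[left].mean()) ** 2).sum()
             + ((y[~left] - y[~left].mean()) ** 2).sum())
      print(sorted(order[:i]), round(sse, 2))
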
.. topic:: Examples:

  * :ref:`sphx_glr_auto_examples_ensemble_plot_gradient_boosting_categorical.py`

.. _monotonic_cst_gbdt:

Monotonic Constraints
@@ -1092,6 +1154,10 @@ that the feature is supposed to have a positive / negative effect on the
probability to belong to the positive class. Monotonic constraints are not
supported in a multiclass context.

.. note::
    Since categories are unordered quantities, it is not possible to enforce
    monotonic constraints on categorical features.

.. topic:: Examples:

  * :ref:`sphx_glr_auto_examples_ensemble_plot_monotonic_constraints.py`
@@ -1158,6 +1224,8 @@ Finally, many parts of the implementation of
.. [LightGBM] Ke et al. `"LightGBM: A Highly Efficient Gradient
   Boosting Decision Tree" <https://papers.nips.cc/paper/
   6907-lightgbm-a-highly-efficient-gradient-boosting-decision-tree>`_
.. [Fisher1958] Walter D. Fisher. `"On Grouping for Maximum Homogeneity"
   <http://www.csiss.org/SPACE/workshops/2004/SAC/files/fisher.pdf>`_

.. _voting_classifier:

doc/whats_new/v0.24.rst

Lines changed: 5 additions & 0 deletions
@@ -242,6 +242,11 @@ Changelog
:mod:`sklearn.ensemble`
.......................

- |MajorFeature| :class:`ensemble.HistGradientBoostingRegressor` and
  :class:`ensemble.HistGradientBoostingClassifier` now have native
  support for categorical features with the `categorical_features`
  parameter. :pr:`18394` by `Nicolas Hug`_ and `Thomas Fan`_.

- |Feature| :class:`ensemble.HistGradientBoostingRegressor` and
  :class:`ensemble.HistGradientBoostingClassifier` now support the
  method `staged_predict`, which allows monitoring of each stage.
