Skip to content

Commit 1e184f8

Browse files
authored
[ENH] Add configurable predictor parameters (LAMDA-NJU#10)
1 parent f82facf commit 1e184f8

File tree

3 files changed

+68
-16
lines changed

3 files changed

+68
-16
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,5 @@ Version 0.1.*
3030
.. |Enhancement| replace:: :raw-html:`<span class="badge badge-info">Enhancement</span>` :raw-latex:`{\small\sc [Enhancement]}`
3131
.. |Fix| replace:: :raw-html:`<span class="badge badge-danger">Fix</span>` :raw-latex:`{\small\sc [Fix]}`
3232
.. |API| replace:: :raw-html:`<span class="badge badge-warning">API Change</span>` :raw-latex:`{\small\sc [API Change]}`
33+
34+
- |Feature| configurable predictor parameters `#10 <https://github.com/LAMDA-NJU/Deep-Forest/issues/10>`__

deepforest/cascade.py

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,23 @@
1414
from ._binner import Binner
1515

1616

17+
def _get_predictor_kwargs(predictor_kwargs, **kwargs) -> dict:
18+
"""Overwrites default args if predictor_kwargs is supplied."""
19+
for key, value in kwargs.items():
20+
if key not in predictor_kwargs.keys():
21+
predictor_kwargs[key] = value
22+
return predictor_kwargs
23+
24+
1725
def _build_predictor(
1826
predictor_name,
1927
n_estimators,
2028
n_outputs,
2129
max_depth=None,
2230
min_samples_leaf=1,
2331
n_jobs=None,
24-
random_state=None
32+
random_state=None,
33+
predictor_kwargs={},
2534
):
2635
"""Build the predictor concatenated to the deep forest."""
2736
predictor_name = predictor_name.lower()
@@ -30,11 +39,14 @@ def _build_predictor(
3039
if predictor_name == "forest":
3140
from .forest import RandomForestClassifier
3241
predictor = RandomForestClassifier(
33-
n_estimators=n_estimators,
34-
max_depth=max_depth,
35-
min_samples_leaf=min_samples_leaf,
36-
n_jobs=n_jobs,
37-
random_state=random_state,
42+
**_get_predictor_kwargs(
43+
predictor_kwargs,
44+
n_estimators=n_estimators,
45+
max_depth=max_depth,
46+
min_samples_leaf=min_samples_leaf,
47+
n_jobs=n_jobs,
48+
random_state=random_state,
49+
)
3850
)
3951
# XGBoost
4052
elif predictor_name == "xgboost":
@@ -51,11 +63,14 @@ def _build_predictor(
5163
# because the exact mode of XGBoost is too slow.
5264
objective = "multi:softmax" if n_outputs > 2 else "binary:logistic"
5365
predictor = xgb.sklearn.XGBClassifier(
54-
objective=objective,
55-
n_estimators=n_estimators,
56-
tree_method="hist",
57-
n_jobs=n_jobs,
58-
random_state=random_state,
66+
**_get_predictor_kwargs(
67+
predictor_kwargs,
68+
objective=objective,
69+
n_estimators=n_estimators,
70+
tree_method="hist",
71+
n_jobs=n_jobs,
72+
random_state=random_state,
73+
)
5974
)
6075
# LightGBM
6176
elif predictor_name == "lightgbm":
@@ -70,10 +85,13 @@ def _build_predictor(
7085

7186
objective = "multiclass" if n_outputs > 2 else "binary"
7287
predictor = lgb.LGBMClassifier(
73-
objective=objective,
74-
n_estimators=n_estimators,
75-
n_jobs=n_jobs,
76-
random_state=random_state,
88+
**_get_predictor_kwargs(
89+
predictor_kwargs,
90+
objective=objective,
91+
n_estimators=n_estimators,
92+
n_jobs=n_jobs,
93+
random_state=random_state,
94+
)
7795
)
7896
else:
7997
msg = (
@@ -117,6 +135,11 @@ def _build_predictor(
117135
predictor : :obj:`{"forest", "xgboost", "lightgbm"}`, default="forest"
118136
The type of the predictor concatenated to the deep forest. If
119137
``use_predictor`` is False, this parameter will have no effect.
138+
predictor_kwargs : :obj:`dict`, default={}
139+
The configuration of the predictor concatenated to the deep forest.
140+
Specifying this will extend/overwrite the original parameters inherited
141+
from deep forest.
142+
If ``use_predictor`` is False, this parameter will have no effect.
120143
n_tolerant_rounds : :obj:`int`, default=2
121144
Specify when to conduct early stopping. The training process
122145
terminates when the validation performance on the training set does
@@ -182,6 +205,7 @@ def __init__(
182205
min_samples_leaf=1,
183206
use_predictor=False,
184207
predictor="forest",
208+
predictor_kwargs={},
185209
n_tolerant_rounds=2,
186210
delta=1e-5,
187211
partial_mode=False,
@@ -197,6 +221,7 @@ def __init__(
197221
self.n_trees = n_trees
198222
self.max_depth = max_depth
199223
self.min_samples_leaf = min_samples_leaf
224+
self.predictor_kwargs = predictor_kwargs
200225
self.n_tolerant_rounds = n_tolerant_rounds
201226
self.delta = delta
202227
self.partial_mode = partial_mode
@@ -618,7 +643,8 @@ def fit(self, X, y):
618643
self.max_depth,
619644
self.min_samples_leaf,
620645
self.n_jobs,
621-
self.random_state
646+
self.random_state,
647+
self.predictor_kwargs,
622648
)
623649

624650
binner_ = Binner(

tests/test_model.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import deepforest
99
from deepforest import CascadeForestClassifier
10+
from deepforest.cascade import _get_predictor_kwargs
1011

1112

1213
save_dir = "./tmp"
@@ -26,6 +27,7 @@
2627
"min_samples_leaf": 1,
2728
"use_predictor": True,
2829
"predictor": "forest",
30+
"predictor_kwargs": {},
2931
"n_tolerant_rounds": 2,
3032
"delta": 1e-5,
3133
"n_jobs": -1,
@@ -41,13 +43,35 @@
4143
"min_samples_leaf": 1,
4244
"use_predictor": True,
4345
"predictor": "forest",
46+
"predictor_kwargs": {},
4447
"n_tolerant_rounds": 2,
4548
"delta": 1e-5,
4649
"n_jobs": -1,
4750
"random_state": 0,
4851
"verbose": 2}
4952

5053

54+
@pytest.mark.parametrize(
    "test_input,expected",
    [
        # Empty predictor_kwargs: the defaults pass through untouched.
        ({"predictor_kwargs": {}, "n_job": 2}, {"n_job": 2}),
        # A user value for an existing key wins over the default.
        ({"predictor_kwargs": {"n_job": 3}, "n_job": 2}, {"n_job": 3}),
        # An extra user key is merged alongside the default.
        ({"predictor_kwargs": {"iter": 4}, "n_job": 2}, {"iter": 4, "n_job": 2}),
    ],
)
def test_predictor_kwargs_overwrite(test_input, expected):
    """Check that user-supplied predictor kwargs override the defaults."""
    merged = _get_predictor_kwargs(**test_input)
    assert merged == expected
73+
74+
5175
def test_model_properties_after_fitting():
5276
"""Check the model properties after fitting a deep forest model."""
5377
model = CascadeForestClassifier(**toy_kwargs)

0 commit comments

Comments
 (0)