14
14
from ._binner import Binner
15
15
16
16
17
+ def _get_predictor_kwargs (predictor_kwargs , ** kwargs ) -> dict :
18
+ """Overwrites default args if predictor_kwargs is supplied."""
19
+ for key , value in kwargs .items ():
20
+ if key not in predictor_kwargs .keys ():
21
+ predictor_kwargs [key ] = value
22
+ return predictor_kwargs
23
+
24
+
17
25
def _build_predictor (
18
26
predictor_name ,
19
27
n_estimators ,
20
28
n_outputs ,
21
29
max_depth = None ,
22
30
min_samples_leaf = 1 ,
23
31
n_jobs = None ,
24
- random_state = None
32
+ random_state = None ,
33
+ predictor_kwargs = {},
25
34
):
26
35
"""Build the predictor concatenated to the deep forest."""
27
36
predictor_name = predictor_name .lower ()
@@ -30,11 +39,14 @@ def _build_predictor(
30
39
if predictor_name == "forest" :
31
40
from .forest import RandomForestClassifier
32
41
predictor = RandomForestClassifier (
33
- n_estimators = n_estimators ,
34
- max_depth = max_depth ,
35
- min_samples_leaf = min_samples_leaf ,
36
- n_jobs = n_jobs ,
37
- random_state = random_state ,
42
+ ** _get_predictor_kwargs (
43
+ predictor_kwargs ,
44
+ n_estimators = n_estimators ,
45
+ max_depth = max_depth ,
46
+ min_samples_leaf = min_samples_leaf ,
47
+ n_jobs = n_jobs ,
48
+ random_state = random_state ,
49
+ )
38
50
)
39
51
# XGBoost
40
52
elif predictor_name == "xgboost" :
@@ -51,11 +63,14 @@ def _build_predictor(
51
63
# because the exact mode of XGBoost is too slow.
52
64
objective = "multi:softmax" if n_outputs > 2 else "binary:logistic"
53
65
predictor = xgb .sklearn .XGBClassifier (
54
- objective = objective ,
55
- n_estimators = n_estimators ,
56
- tree_method = "hist" ,
57
- n_jobs = n_jobs ,
58
- random_state = random_state ,
66
+ ** _get_predictor_kwargs (
67
+ predictor_kwargs ,
68
+ objective = objective ,
69
+ n_estimators = n_estimators ,
70
+ tree_method = "hist" ,
71
+ n_jobs = n_jobs ,
72
+ random_state = random_state ,
73
+ )
59
74
)
60
75
# LightGBM
61
76
elif predictor_name == "lightgbm" :
@@ -70,10 +85,13 @@ def _build_predictor(
70
85
71
86
objective = "multiclass" if n_outputs > 2 else "binary"
72
87
predictor = lgb .LGBMClassifier (
73
- objective = objective ,
74
- n_estimators = n_estimators ,
75
- n_jobs = n_jobs ,
76
- random_state = random_state ,
88
+ ** _get_predictor_kwargs (
89
+ predictor_kwargs ,
90
+ objective = objective ,
91
+ n_estimators = n_estimators ,
92
+ n_jobs = n_jobs ,
93
+ random_state = random_state ,
94
+ )
77
95
)
78
96
else :
79
97
msg = (
@@ -117,6 +135,11 @@ def _build_predictor(
117
135
predictor : :obj:`{"forest", "xgboost", "lightgbm"}`, default="forest"
118
136
The type of the predictor concatenated to the deep forest. If
119
137
``use_predictor`` is False, this parameter will have no effect.
138
+ predictor_kwargs : :obj:`dict`, default={}
139
+ The configuration of the predictor concatenated to the deep forest.
140
+ Specifying this will extend/overwrite the original parameters inherited
141
+ from the deep forest.
142
+ If ``use_predictor`` is False, this parameter will have no effect.
120
143
n_tolerant_rounds : :obj:`int`, default=2
121
144
Specify when to conduct early stopping. The training process
122
145
terminates when the validation performance on the training set does
@@ -182,6 +205,7 @@ def __init__(
182
205
min_samples_leaf = 1 ,
183
206
use_predictor = False ,
184
207
predictor = "forest" ,
208
+ predictor_kwargs = {},
185
209
n_tolerant_rounds = 2 ,
186
210
delta = 1e-5 ,
187
211
partial_mode = False ,
@@ -197,6 +221,7 @@ def __init__(
197
221
self .n_trees = n_trees
198
222
self .max_depth = max_depth
199
223
self .min_samples_leaf = min_samples_leaf
224
+ self .predictor_kwargs = predictor_kwargs
200
225
self .n_tolerant_rounds = n_tolerant_rounds
201
226
self .delta = delta
202
227
self .partial_mode = partial_mode
@@ -618,7 +643,8 @@ def fit(self, X, y):
618
643
self .max_depth ,
619
644
self .min_samples_leaf ,
620
645
self .n_jobs ,
621
- self .random_state
646
+ self .random_state ,
647
+ self .predictor_kwargs ,
622
648
)
623
649
624
650
binner_ = Binner (
0 commit comments