|
15 | 15 | from ..ensemble._hist_gradient_boosting.gradient_boosting import ( |
16 | 16 | BaseHistGradientBoosting, |
17 | 17 | ) |
18 | | -from ..exceptions import NotFittedError |
19 | 18 | from ..tree import DecisionTreeRegressor |
20 | 19 | from ..utils import Bunch, _safe_indexing, check_array |
21 | 20 | from ..utils._indexing import _determine_key_type, _get_column_indices, _safe_assign |
|
27 | 26 | StrOptions, |
28 | 27 | validate_params, |
29 | 28 | ) |
| 29 | +from ..utils._response import _get_response_values |
30 | 30 | from ..utils.extmath import cartesian |
31 | 31 | from ..utils.validation import _check_sample_weight, check_is_fitted |
32 | 32 | from ._pd_utils import _check_feature_names, _get_feature_index |
@@ -261,51 +261,27 @@ def _partial_dependence_brute( |
261 | 261 | predictions = [] |
262 | 262 | averaged_predictions = [] |
263 | 263 |
|
264 | | - # define the prediction_method (predict, predict_proba, decision_function). |
265 | | - if is_regressor(est): |
266 | | - prediction_method = est.predict |
267 | | - else: |
268 | | - predict_proba = getattr(est, "predict_proba", None) |
269 | | - decision_function = getattr(est, "decision_function", None) |
270 | | - if response_method == "auto": |
271 | | - # try predict_proba, then decision_function if it doesn't exist |
272 | | - prediction_method = predict_proba or decision_function |
273 | | - else: |
274 | | - prediction_method = ( |
275 | | - predict_proba |
276 | | - if response_method == "predict_proba" |
277 | | - else decision_function |
278 | | - ) |
279 | | - if prediction_method is None: |
280 | | - if response_method == "auto": |
281 | | - raise ValueError( |
282 | | - "The estimator has no predict_proba and no " |
283 | | - "decision_function method." |
284 | | - ) |
285 | | - elif response_method == "predict_proba": |
286 | | - raise ValueError("The estimator has no predict_proba method.") |
287 | | - else: |
288 | | - raise ValueError("The estimator has no decision_function method.") |
| 264 | + if response_method == "auto": |
| 265 | + response_method = ( |
| 266 | + "predict" if is_regressor(est) else ["predict_proba", "decision_function"] |
| 267 | + ) |
289 | 268 |
|
290 | 269 | X_eval = X.copy() |
291 | 270 | for new_values in grid: |
292 | 271 | for i, variable in enumerate(features): |
293 | 272 | _safe_assign(X_eval, new_values[i], column_indexer=variable) |
294 | 273 |
|
295 | | - try: |
296 | | - # Note: predictions is of shape |
297 | | - # (n_points,) for non-multioutput regressors |
298 | | - # (n_points, n_tasks) for multioutput regressors |
299 | | - # (n_points, 1) for the regressors in cross_decomposition (I think) |
300 | | - # (n_points, 2) for binary classification |
301 | | - # (n_points, n_classes) for multiclass classification |
302 | | - pred = prediction_method(X_eval) |
303 | | - |
304 | | - predictions.append(pred) |
305 | | - # average over samples |
306 | | - averaged_predictions.append(np.average(pred, axis=0, weights=sample_weight)) |
307 | | - except NotFittedError as e: |
308 | | - raise ValueError("'estimator' parameter must be a fitted estimator") from e |
| 274 | + # Note: predictions is of shape |
| 275 | + # (n_points,) for non-multioutput regressors |
| 276 | + # (n_points, n_tasks) for multioutput regressors |
| 277 | + # (n_points, 1) for the regressors in cross_decomposition (I think) |
| 278 | + # (n_points, 2) for binary classification |
| 279 | + # (n_points, n_classes) for multiclass classification |
| 280 | + pred, _ = _get_response_values(est, X_eval, response_method=response_method) |
| 281 | + |
| 282 | + predictions.append(pred) |
| 283 | + # average over samples |
| 284 | + averaged_predictions.append(np.average(pred, axis=0, weights=sample_weight)) |
309 | 285 |
|
310 | 286 | n_samples = X.shape[0] |
311 | 287 |
|
|
0 commit comments