Skip to content

Commit f3b1da3

Browse files
ENH Add decision_function, predict_proba and predict_log_proba for NearestCentroid estimator (scikit-learn#26689)
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent 43ab714 commit f3b1da3

File tree

4 files changed

+326
-56
lines changed

4 files changed

+326
-56
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
- Add :meth:`neighbors.NearestCentroid.decision_function`,
2+
:meth:`neighbors.NearestCentroid.predict_proba` and
3+
:meth:`neighbors.NearestCentroid.predict_log_proba`
4+
to the :class:`neighbors.NearestCentroid` estimator class.
5+
Support the case when `X` is sparse and `shrinking_threshold`
6+
is not `None` in :class:`neighbors.NearestCentroid`.
7+
By :user:`Matthew Ning <NoPenguinsLand>`

sklearn/discriminant_analysis.py

Lines changed: 88 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,84 @@ def _class_cov(X, y, priors, shrinkage=None, covariance_estimator=None):
168168
return cov
169169

170170

171+
class DiscriminantAnalysisPredictionMixin:
    """Mixin class for QuadraticDiscriminantAnalysis and NearestCentroid.

    The host estimator is expected to provide:

    - ``classes_``: the array of class labels, one per column of the scores;
    - ``_decision_function(X)``: per-class scores as a 2D array with one row
      per sample and one column per class.

    All public prediction methods below are derived from those two members.
    """

    def decision_function(self, X):
        """Apply decision function to an array of samples.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Array of samples (test vectors).

        Returns
        -------
        y_scores : ndarray of shape (n_samples,) or (n_samples, n_classes)
            Decision function values related to each class, per sample.
            In the two-class case, the shape is `(n_samples,)`, giving the
            log likelihood ratio of the positive class.
        """
        y_scores = self._decision_function(X)
        if len(self.classes_) == 2:
            # Binary problems are collapsed to a single column: the score of
            # the second class relative to the first one.
            return y_scores[:, 1] - y_scores[:, 0]
        return y_scores

    def predict(self, X):
        """Perform classification on an array of vectors `X`.

        Returns the class label for each sample.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Input vectors, where `n_samples` is the number of samples and
            `n_features` is the number of features.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,)
            Class label for each sample.
        """
        # Use the raw per-class scores (not `decision_function`) so that the
        # binary case still has one column per class to take the argmax over.
        best_class_idx = self._decision_function(X).argmax(axis=1)
        return self.classes_.take(best_class_idx)

    def predict_proba(self, X):
        """Estimate class probabilities.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Input data.

        Returns
        -------
        y_proba : ndarray of shape (n_samples, n_classes)
            Probability estimate of the sample for each class in the
            model, where classes are ordered as they are in `self.classes_`.
        """
        return np.exp(self.predict_log_proba(X))

    def predict_log_proba(self, X):
        """Estimate log class probabilities.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Input data.

        Returns
        -------
        y_log_proba : ndarray of shape (n_samples, n_classes)
            Estimated log probabilities.
        """
        scores = self._decision_function(X)
        # Shift every row by its maximum before exponentiating: the largest
        # term becomes exp(0) == 1, so the sum below cannot overflow.
        shifted = scores - scores.max(axis=1, keepdims=True)
        # Row-wise log-softmax: subtract the log of the normalising constant.
        return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
247+
248+
171249
class LinearDiscriminantAnalysis(
172250
ClassNamePrefixFeaturesOutMixin,
173251
LinearClassifierMixin,
@@ -744,9 +822,9 @@ def decision_function(self, X):
744822
745823
Returns
746824
-------
747-
C : ndarray of shape (n_samples,) or (n_samples, n_classes)
825+
y_scores : ndarray of shape (n_samples,) or (n_samples, n_classes)
748826
Decision function values related to each class, per sample.
749-
In the two-class case, the shape is (n_samples,), giving the
827+
In the two-class case, the shape is `(n_samples,)`, giving the
750828
log likelihood ratio of the positive class.
751829
"""
752830
# Only override for the doc
@@ -758,7 +836,9 @@ def __sklearn_tags__(self):
758836
return tags
759837

760838

761-
class QuadraticDiscriminantAnalysis(ClassifierMixin, BaseEstimator):
839+
class QuadraticDiscriminantAnalysis(
840+
DiscriminantAnalysisPredictionMixin, ClassifierMixin, BaseEstimator
841+
):
762842
"""Quadratic Discriminant Analysis.
763843
764844
A classifier with a quadratic decision boundary, generated
@@ -992,14 +1072,10 @@ def decision_function(self, X):
9921072
-------
9931073
C : ndarray of shape (n_samples,) or (n_samples, n_classes)
9941074
Decision function values related to each class, per sample.
995-
In the two-class case, the shape is (n_samples,), giving the
1075+
In the two-class case, the shape is `(n_samples,)`, giving the
9961076
log likelihood ratio of the positive class.
9971077
"""
998-
dec_func = self._decision_function(X)
999-
# handle special case of two classes
1000-
if len(self.classes_) == 2:
1001-
return dec_func[:, 1] - dec_func[:, 0]
1002-
return dec_func
1078+
return super().decision_function(X)
10031079

10041080
def predict(self, X):
10051081
"""Perform classification on an array of test vectors X.
@@ -1017,9 +1093,7 @@ def predict(self, X):
10171093
C : ndarray of shape (n_samples,)
10181094
Predicted class label for each sample.
10191095
"""
1020-
d = self._decision_function(X)
1021-
y_pred = self.classes_.take(d.argmax(1))
1022-
return y_pred
1096+
return super().predict(X)
10231097

10241098
def predict_proba(self, X):
10251099
"""Return posterior probabilities of classification.
@@ -1034,12 +1108,9 @@ def predict_proba(self, X):
10341108
C : ndarray of shape (n_samples, n_classes)
10351109
Posterior probabilities of classification per class.
10361110
"""
1037-
values = self._decision_function(X)
10381111
# compute the likelihood of the underlying gaussian models
10391112
# up to a multiplicative constant.
1040-
likelihood = np.exp(values - values.max(axis=1)[:, np.newaxis])
1041-
# compute posterior probabilities
1042-
return likelihood / likelihood.sum(axis=1)[:, np.newaxis]
1113+
return super().predict_proba(X)
10431114

10441115
def predict_log_proba(self, X):
10451116
"""Return log of posterior probabilities of classification.
@@ -1055,5 +1126,4 @@ def predict_log_proba(self, X):
10551126
Posterior log-probabilities of classification per class.
10561127
"""
10571128
# Delegates to DiscriminantAnalysisPredictionMixin, which computes a max-shifted log-softmax so the exponentials cannot overflow.
1058-
probas_ = self.predict_proba(X)
1059-
return np.log(probas_)
1129+
return super().predict_log_proba(X)

0 commit comments

Comments
 (0)