Skip to content

Commit c2d2307

Browse files
kyleabeauchampraghavrv
authored andcommitted
Added feature to fix scikit-learn#1523
1 parent a3283c6 commit c2d2307

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ Enhancements
152152

153153
- Add ``n_iter_`` attribute to estimators that accept a ``max_iter`` attribute
154154
in their constructor. By `Manoj Kumar`_.
155+
156+
- Added decision function for :class:`multiclass.OneVsOneClassifier`
157+
By `Raghav R V`_ and `Kyle Beauchamp`_.
155158

156159
- :func:`neighbors.kneighbors_graph` and :func:`radius_neighbors_graph`
157160
support non-Euclidean metrics. By `Manoj Kumar`_

sklearn/multiclass.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,35 @@ def predict(self, X):
551551
prediction = votes.argmax(axis=1)
552552
return self.classes_[prediction]
553553

554+
def decision_function(self, X):
555+
"""Distance (votes) of the samples X to the separating hyperplanes.
556+
Parameters
557+
----------
558+
X : array-like, shape = [n_samples, n_features]
559+
Returns
560+
-------
561+
D : array-like, shape = [n_samples, n_class * (n_class-1) / 2]
562+
Returns the decision function of the sample for each class
563+
in the model.
564+
"""
565+
if not hasattr(self, "estimators_"):
566+
raise ValueError("The object hasn't been fitted yet!")
567+
568+
# Predict decision function (votes) using the one-vs-one strategy.
569+
n_samples = X.shape[0]
570+
n_classes = self.classes_.shape[0]
571+
votes = np.zeros((n_samples, n_classes))
572+
573+
k = 0
574+
for i in range(n_classes):
575+
for j in range(i + 1, n_classes):
576+
pred = self.estimators_[k].predict(X)
577+
votes[pred == 0, i] += 1
578+
votes[pred == 1, j] += 1
579+
k += 1
580+
581+
return votes
582+
554583

555584
@deprecated("fit_ecoc is deprecated and will be removed in 0.18."
556585
"Use the OutputCodeClassifier instead.")

0 commit comments

Comments
 (0)