Skip to content

Commit 6f9339a

Browse files
committed
TST fix import in test
1 parent 2881702 commit 6f9339a

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

sklearn/svm/tests/test_svm.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,11 +336,13 @@ def test_auto_weight():
336336
# we take as dataset a the two-dimensional projection of iris so
337337
# that it is not separable and remove half of predictors from
338338
# class 1
339-
from sklearn.svm.base import _get_class_weight
339+
from sklearn.utils import compute_class_weight
340340
X, y = iris.data[:, :2], iris.target
341341
unbalanced = np.delete(np.arange(y.size), np.where(y > 1)[0][::2])
342-
343-
assert_true(np.argmax(_get_class_weight('auto', y[unbalanced])[0]) == 2)
342+
343+
classes = np.unique(y[unbalanced])
344+
class_weights = compute_class_weight('auto', classes, y[unbalanced])
345+
assert_true(np.argmax(class_weights) == 2)
344346

345347
for clf in (svm.SVC(kernel='linear'), svm.LinearSVC(random_state=0),
346348
LogisticRegression()):

0 commit comments

Comments
 (0)