Skip to content

Commit b694184

Browse files
vincentpham1991jnothman
authored andcommitted
[MRG + 1] FIX bug where passing numpy array for weights raises error (Issue scikit-learn#7983) (scikit-learn#7989)
1 parent d454051 commit b694184

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

doc/whats_new.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,10 @@ Bug fixes
108108
``pandas.Series`` in their ``fit`` function. :issue:`7825` by
109109
`Kathleen Chen`_.
110110

111+
- Fix a bug where :class:`sklearn.ensemble.VotingClassifier` raises an error
112+
when a numpy array is passed in for weights. :issue:`7983` by
113+
:user:`Vincent Pham <vincentpham1991>`.
114+
111115
.. _changes_0_18_1:
112116

113117
Version 0.18.1
@@ -4830,3 +4834,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
48304834
.. _Ron Weiss: http://www.ee.columbia.edu/~ronw
48314835

48324836
.. _Kathleen Chen: https://github.com/kchen17
4837+
4838+
.. _Vincent Pham: https://github.com/vincentpham1991

sklearn/ensemble/tests/test_voting_classifier.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,3 +258,20 @@ def test_sample_weight():
258258
voting='soft')
259259
msg = ('Underlying estimator \'knn\' does not support sample weights.')
260260
assert_raise_message(ValueError, msg, eclf3.fit, X, y, sample_weight)
261+
262+
263+
def test_estimator_weights_format():
264+
# Test estimator weights inputs as list and array
265+
clf1 = LogisticRegression(random_state=123)
266+
clf2 = RandomForestClassifier(random_state=123)
267+
eclf1 = VotingClassifier(estimators=[
268+
('lr', clf1), ('rf', clf2)],
269+
weights=[1, 2],
270+
voting='soft')
271+
eclf2 = VotingClassifier(estimators=[
272+
('lr', clf1), ('rf', clf2)],
273+
weights=np.array((1, 2)),
274+
voting='soft')
275+
eclf1.fit(X, y)
276+
eclf2.fit(X, y)
277+
assert_array_equal(eclf1.predict_proba(X), eclf2.predict_proba(X))

sklearn/ensemble/voting_classifier.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ def fit(self, X, y, sample_weight=None):
141141
' should be a list of (string, estimator)'
142142
' tuples')
143143

144-
if self.weights and len(self.weights) != len(self.estimators):
144+
if (self.weights is not None and
145+
len(self.weights) != len(self.estimators)):
145146
raise ValueError('Number of classifiers and weights must be equal'
146147
'; got %d weights, %d estimators'
147148
% (len(self.weights), len(self.estimators)))

0 commit comments

Comments
 (0)