1313
1414import numpy as np
1515
16- from ..base import BaseEstimator
1716from ..base import ClassifierMixin
1817from ..base import TransformerMixin
1918from ..base import clone
2019from ..preprocessing import LabelEncoder
21- from ..externals import six
2220from ..externals .joblib import Parallel , delayed
2321from ..utils .validation import has_fit_parameter , check_is_fitted
22+ from ..utils .metaestimators import _BaseComposition
2423
2524
2625def _parallel_fit_estimator (estimator , X , y , sample_weight ):
@@ -32,7 +31,7 @@ def _parallel_fit_estimator(estimator, X, y, sample_weight):
3231 return estimator
3332
3433
35- class VotingClassifier (BaseEstimator , ClassifierMixin , TransformerMixin ):
34+ class VotingClassifier (_BaseComposition , ClassifierMixin , TransformerMixin ):
3635 """Soft Voting/Majority Rule classifier for unfitted estimators.
3736
3837 .. versionadded:: 0.17
@@ -44,7 +43,8 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
4443 estimators : list of (string, estimator) tuples
4544 Invoking the ``fit`` method on the ``VotingClassifier`` will fit clones
4645 of those original estimators that will be stored in the class attribute
47- `self.estimators_`.
46+ ``self.estimators_``. An estimator can be set to `None` using
47+ ``set_params``.
4848
4949 voting : str, {'hard', 'soft'} (default='hard')
5050 If 'hard', uses predicted class labels for majority rule voting.
@@ -64,7 +64,8 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
6464 Attributes
6565 ----------
6666 estimators_ : list of classifiers
67- The collection of fitted sub-estimators.
67+ The collection of fitted sub-estimators as defined in ``estimators``
68+ that are not `None`.
6869
6970 classes_ : array-like, shape = [n_predictions]
7071 The class labels.
@@ -102,11 +103,14 @@ class VotingClassifier(BaseEstimator, ClassifierMixin, TransformerMixin):
102103
103104 def __init__ (self , estimators , voting = 'hard' , weights = None , n_jobs = 1 ):
104105 self .estimators = estimators
105- self .named_estimators = dict (estimators )
106106 self .voting = voting
107107 self .weights = weights
108108 self .n_jobs = n_jobs
109109
110+ @property
111+ def named_estimators (self ):
112+ return dict (self .estimators )
113+
110114 def fit (self , X , y , sample_weight = None ):
111115 """ Fit the estimators.
112116
@@ -150,23 +154,36 @@ def fit(self, X, y, sample_weight=None):
150154 if sample_weight is not None :
151155 for name , step in self .estimators :
152156 if not has_fit_parameter (step , 'sample_weight' ):
153- raise ValueError ('Underlying estimator \' %s\' does not support'
154- ' sample weights.' % name )
155-
156- self .le_ = LabelEncoder ()
157- self .le_ .fit (y )
157+ raise ValueError ('Underlying estimator \' %s\' does not'
158+ ' support sample weights.' % name )
159+ names , clfs = zip (* self .estimators )
160+ self ._validate_names (names )
161+
162+ n_isnone = np .sum ([clf is None for _ , clf in self .estimators ])
163+ if n_isnone == len (self .estimators ):
164+ raise ValueError ('All estimators are None. At least one is '
165+ 'required to be a classifier!' )
166+ self .le_ = LabelEncoder ().fit (y )
158167 self .classes_ = self .le_ .classes_
159168 self .estimators_ = []
160169
161170 transformed_y = self .le_ .transform (y )
162171
163172 self .estimators_ = Parallel (n_jobs = self .n_jobs )(
164173 delayed (_parallel_fit_estimator )(clone (clf ), X , transformed_y ,
165- sample_weight )
166- for _ , clf in self . estimators )
174+ sample_weight )
175+ for clf in clfs if clf is not None )
167176
168177 return self
169178
179+ @property
180+ def _weights_not_none (self ):
181+ """Get the weights of not `None` estimators"""
182+ if self .weights is None :
183+ return None
184+ return [w for est , w in zip (self .estimators ,
185+ self .weights ) if est [1 ] is not None ]
186+
170187 def predict (self , X ):
171188 """ Predict class labels for X.
172189
@@ -188,11 +205,10 @@ def predict(self, X):
188205
189206 else : # 'hard' voting
190207 predictions = self ._predict (X )
191- maj = np .apply_along_axis (lambda x :
192- np .argmax (np .bincount (x ,
193- weights = self .weights )),
194- axis = 1 ,
195- arr = predictions .astype ('int' ))
208+ maj = np .apply_along_axis (
209+ lambda x : np .argmax (
210+ np .bincount (x , weights = self ._weights_not_none )),
211+ axis = 1 , arr = predictions .astype ('int' ))
196212
197213 maj = self .le_ .inverse_transform (maj )
198214
@@ -208,7 +224,8 @@ def _predict_proba(self, X):
208224 raise AttributeError ("predict_proba is not available when"
209225 " voting=%r" % self .voting )
210226 check_is_fitted (self , 'estimators_' )
211- avg = np .average (self ._collect_probas (X ), axis = 0 , weights = self .weights )
227+ avg = np .average (self ._collect_probas (X ), axis = 0 ,
228+ weights = self ._weights_not_none )
212229 return avg
213230
214231 @property
@@ -252,17 +269,42 @@ def transform(self, X):
252269 else :
253270 return self ._predict (X )
254271
272+ def set_params (self , ** params ):
273+ """ Setting the parameters for the voting classifier
274+
275+ Valid parameter keys can be listed with get_params().
276+
277+ Parameters
278+ ----------
279+ params: keyword arguments
280+ Specific parameters using e.g. set_params(parameter_name=new_value)
281+ In addition, to setting the parameters of the ``VotingClassifier``,
282+ the individual classifiers of the ``VotingClassifier`` can also be
283+ set or replaced by setting them to None.
284+
285+ Examples
286+ --------
287+ # In this example, the RandomForestClassifier is removed
288+ clf1 = LogisticRegression()
289+ clf2 = RandomForestClassifier()
290+ eclf = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2)]
291+ eclf.set_params(rf=None)
292+
293+ """
294+ super (VotingClassifier , self )._set_params ('estimators' , ** params )
295+ return self
296+
255297 def get_params (self , deep = True ):
256- """Return estimator parameter names for GridSearch support"""
257- if not deep :
258- return super ( VotingClassifier , self ). get_params ( deep = False )
259- else :
260- out = super ( VotingClassifier , self ). get_params ( deep = False )
261- out . update ( self . named_estimators . copy ())
262- for name , step in six . iteritems ( self . named_estimators ):
263- for key , value in six . iteritems ( step . get_params ( deep = True )):
264- out [ '%s__%s' % ( name , key )] = value
265- return out
298+ """ Get the parameters of the VotingClassifier
299+
300+ Parameters
301+ ----------
302+ deep: bool
303+ Setting it to True gets the various classifiers and the parameters
304+ of the classifiers as well
305+ """
306+ return super ( VotingClassifier ,
307+ self ). _get_params ( 'estimators' , deep = deep )
266308
267309 def _predict (self , X ):
268310 """Collect results from clf.predict calls. """
0 commit comments