Skip to content

Commit 03ea20d

Browse files
amuellerrth
authored andcommitted
Fix mixin inheritance order, allow overwriting tags (scikit-learn#14884)
1 parent 96bfae6 commit 03ea20d

File tree

90 files changed

+170
-176
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

90 files changed

+170
-176
lines changed

examples/compose/plot_column_transformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from sklearn.svm import LinearSVC
4343

4444

45-
class TextStats(BaseEstimator, TransformerMixin):
45+
class TextStats(TransformerMixin, BaseEstimator):
4646
"""Extract features from each document for DictVectorizer"""
4747

4848
def fit(self, x, y=None):
@@ -54,7 +54,7 @@ def transform(self, posts):
5454
for text in posts]
5555

5656

57-
class SubjectBodyExtractor(BaseEstimator, TransformerMixin):
57+
class SubjectBodyExtractor(TransformerMixin, BaseEstimator):
5858
"""Extract the subject & body from a usenet post in a single pass.
5959
6060
Takes a sequence of strings and produces a dict of sequences. Keys are

sklearn/base.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -129,17 +129,6 @@ def _pprint(params, offset=0, printer=repr):
129129
return lines
130130

131131

132-
def _update_if_consistent(dict1, dict2):
133-
common_keys = set(dict1.keys()).intersection(dict2.keys())
134-
for key in common_keys:
135-
if dict1[key] != dict2[key]:
136-
raise TypeError("Inconsistent values for tag {}: {} != {}".format(
137-
key, dict1[key], dict2[key]
138-
))
139-
dict1.update(dict2)
140-
return dict1
141-
142-
143132
class BaseEstimator:
144133
"""Base class for all estimators in scikit-learn
145134
@@ -320,20 +309,19 @@ def __setstate__(self, state):
320309
except AttributeError:
321310
self.__dict__.update(state)
322311

312+
def _more_tags(self):
313+
return _DEFAULT_TAGS
314+
323315
def _get_tags(self):
324316
collected_tags = {}
325-
for base_class in inspect.getmro(self.__class__):
326-
if (hasattr(base_class, '_more_tags')
327-
and base_class != self.__class__):
317+
for base_class in reversed(inspect.getmro(self.__class__)):
318+
if hasattr(base_class, '_more_tags'):
319+
# need the if because mixins might not have _more_tags
320+
# but might do redundant work in estimators
321+
# (i.e. calling more tags on BaseEstimator multiple times)
328322
more_tags = base_class._more_tags(self)
329-
collected_tags = _update_if_consistent(collected_tags,
330-
more_tags)
331-
if hasattr(self, '_more_tags'):
332-
more_tags = self._more_tags()
333-
collected_tags = _update_if_consistent(collected_tags, more_tags)
334-
tags = _DEFAULT_TAGS.copy()
335-
tags.update(collected_tags)
336-
return tags
323+
collected_tags.update(more_tags)
324+
return collected_tags
337325

338326

339327
class ClassifierMixin:

sklearn/calibration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def grad(AB):
465465
return AB_[0], AB_[1]
466466

467467

468-
class _SigmoidCalibration(BaseEstimator, RegressorMixin):
468+
class _SigmoidCalibration(RegressorMixin, BaseEstimator):
469469
"""Sigmoid regression model.
470470
471471
Attributes

sklearn/cluster/affinity_propagation_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def affinity_propagation(S, preference=None, convergence_iter=15, max_iter=200,
233233

234234
###############################################################################
235235

236-
class AffinityPropagation(BaseEstimator, ClusterMixin):
236+
class AffinityPropagation(ClusterMixin, BaseEstimator):
237237
"""Perform Affinity Propagation Clustering of data.
238238
239239
Read more in the :ref:`User Guide <affinity_propagation>`.

sklearn/cluster/bicluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def _log_normalize(X):
8484
return L - row_avg - col_avg + avg
8585

8686

87-
class BaseSpectral(BaseEstimator, BiclusterMixin, metaclass=ABCMeta):
87+
class BaseSpectral(BiclusterMixin, BaseEstimator, metaclass=ABCMeta):
8888
"""Base class for spectral biclustering."""
8989

9090
@abstractmethod

sklearn/cluster/birch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def radius(self):
319319
self.sq_norm_)
320320

321321

322-
class Birch(BaseEstimator, TransformerMixin, ClusterMixin):
322+
class Birch(ClusterMixin, TransformerMixin, BaseEstimator):
323323
"""Implements the Birch clustering algorithm.
324324
325325
It is a memory-efficient, online-learning algorithm provided as an

sklearn/cluster/dbscan_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def dbscan(X, eps=0.5, min_samples=5, metric='minkowski', metric_params=None,
190190
return np.where(core_samples)[0], labels
191191

192192

193-
class DBSCAN(BaseEstimator, ClusterMixin):
193+
class DBSCAN(ClusterMixin, BaseEstimator):
194194
"""Perform DBSCAN clustering from vector array or distance matrix.
195195
196196
DBSCAN - Density-Based Spatial Clustering of Applications with Noise.

sklearn/cluster/hierarchical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ def _hc_cut(n_clusters, children, n_leaves):
652652

653653
###############################################################################
654654

655-
class AgglomerativeClustering(BaseEstimator, ClusterMixin):
655+
class AgglomerativeClustering(ClusterMixin, BaseEstimator):
656656
"""
657657
Agglomerative Clustering
658658

sklearn/cluster/k_means_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ def _init_centroids(X, k, init, random_state=None, x_squared_norms=None,
761761
return centers
762762

763763

764-
class KMeans(BaseEstimator, ClusterMixin, TransformerMixin):
764+
class KMeans(TransformerMixin, ClusterMixin, BaseEstimator):
765765
"""K-Means clustering
766766
767767
Read more in the :ref:`User Guide <k_means>`.

sklearn/cluster/mean_shift_.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ def get_bin_seeds(X, bin_size, min_bin_freq=1):
293293
return bin_seeds
294294

295295

296-
class MeanShift(BaseEstimator, ClusterMixin):
296+
class MeanShift(ClusterMixin, BaseEstimator):
297297
"""Mean shift clustering using a flat kernel.
298298
299299
Mean shift clustering aims to discover "blobs" in a smooth density of

0 commit comments

Comments
 (0)