@@ -129,17 +129,6 @@ def _pprint(params, offset=0, printer=repr):
129129 return lines
130130
131131
132- def _update_if_consistent (dict1 , dict2 ):
133- common_keys = set (dict1 .keys ()).intersection (dict2 .keys ())
134- for key in common_keys :
135- if dict1 [key ] != dict2 [key ]:
136- raise TypeError ("Inconsistent values for tag {}: {} != {}" .format (
137- key , dict1 [key ], dict2 [key ]
138- ))
139- dict1 .update (dict2 )
140- return dict1
141-
142-
143132class BaseEstimator :
144133 """Base class for all estimators in scikit-learn
145134
@@ -320,20 +309,19 @@ def __setstate__(self, state):
320309 except AttributeError :
321310 self .__dict__ .update (state )
322311
312+ def _more_tags (self ):
313+ return _DEFAULT_TAGS
314+
323315 def _get_tags (self ):
324316 collected_tags = {}
325- for base_class in inspect .getmro (self .__class__ ):
326- if (hasattr (base_class , '_more_tags' )
327- and base_class != self .__class__ ):
317+ for base_class in reversed (inspect .getmro (self .__class__ )):
318+ if hasattr (base_class , '_more_tags' ):
319+ # need the if because mixins might not have _more_tags
320+ # but might do redundant work in estimators
321+ # (i.e. calling more tags on BaseEstimator multiple times)
328322 more_tags = base_class ._more_tags (self )
329- collected_tags = _update_if_consistent (collected_tags ,
330- more_tags )
331- if hasattr (self , '_more_tags' ):
332- more_tags = self ._more_tags ()
333- collected_tags = _update_if_consistent (collected_tags , more_tags )
334- tags = _DEFAULT_TAGS .copy ()
335- tags .update (collected_tags )
336- return tags
323+ collected_tags .update (more_tags )
324+ return collected_tags
337325
338326
339327class ClassifierMixin :
0 commit comments