Skip to content

Commit 24e4641

Browse files
phausamannjnothman
authored andcommitted
FIX: clone behavior for estimator types (scikit-learn#12585)
Fixes scikit-learn#12521
1 parent c47c8a9 commit 24e4641

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

sklearn/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def clone(estimator, safe=True):
4848
# XXX: not handling dictionaries
4949
if estimator_type in (list, tuple, set, frozenset):
5050
return estimator_type([clone(e, safe=safe) for e in estimator])
51-
elif not hasattr(estimator, 'get_params'):
51+
elif not hasattr(estimator, 'get_params') or isinstance(estimator, type):
5252
if not safe:
5353
return copy.deepcopy(estimator)
5454
else:

sklearn/tests/test_base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,15 @@ def test_clone_sparse_matrices():
167167
assert_array_equal(clf.empty.toarray(), clf_cloned.empty.toarray())
168168

169169

170+
def test_clone_estimator_types():
171+
# Check that clone works for parameters that are types rather than
172+
# instances
173+
clf = MyEstimator(empty=MyEstimator)
174+
clf2 = clone(clf)
175+
176+
assert clf.empty is clf2.empty
177+
178+
170179
def test_repr():
171180
# Smoke test the repr of the base estimator.
172181
my_estimator = MyEstimator()

0 commit comments

Comments
 (0)