
Commit 92c4308

mjbommar authored and jnothman committed
ENH Dense pipeline support for RandomTreesEmbedding via sparse_output param
1 parent: b041039

2 files changed: +42 −1 lines changed

sklearn/ensemble/forest.py

Lines changed: 7 additions & 1 deletion
@@ -1271,6 +1271,10 @@ class RandomTreesEmbedding(BaseForest):
         If not None then ``max_depth`` will be ignored.
         Note: this parameter is tree-specific.
 
+    sparse_output: bool, optional (default=True)
+        Whether or not to return a sparse CSR matrix, as default behavior,
+        or to return a dense array compatible with dense pipeline operators.
+
     n_jobs : integer, optional (default=1)
         The number of jobs to run in parallel for both `fit` and `predict`.
         If -1, then the number of jobs is set to the number of cores.
@@ -1305,6 +1309,7 @@ def __init__(self,
                  min_samples_split=2,
                  min_samples_leaf=1,
                  max_leaf_nodes=None,
+                 sparse_output=True,
                  n_jobs=1,
                  random_state=None,
                  verbose=0,
@@ -1327,6 +1332,7 @@ def __init__(self,
         self.min_samples_leaf = min_samples_leaf
         self.max_features = 1
         self.max_leaf_nodes = max_leaf_nodes
+        self.sparse_output = sparse_output
 
         if min_density is not None:
             warn("The min_density parameter is deprecated as of version 0.14 "
@@ -1363,7 +1369,7 @@ def fit_transform(self, X, y=None):
         rnd = check_random_state(self.random_state)
         y = rnd.uniform(size=X.shape[0])
         super(RandomTreesEmbedding, self).fit(X, y)
-        self.one_hot_encoder_ = OneHotEncoder()
+        self.one_hot_encoder_ = OneHotEncoder(sparse=self.sparse_output)
         return self.one_hot_encoder_.fit_transform(self.apply(X))
 
     def transform(self, X):
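For context (not part of the commit), a minimal sketch of the kind of fully dense pipeline this parameter enables; the GaussianNB step below is only a hypothetical example of a downstream estimator that expects dense input:

    from sklearn.datasets import make_circles
    from sklearn.ensemble import RandomTreesEmbedding
    from sklearn.naive_bayes import GaussianNB
    from sklearn.pipeline import Pipeline

    X, y = make_circles(factor=0.5, random_state=0)

    # sparse_output=False makes fit_transform return a dense ndarray,
    # so the embedding can feed estimators that reject sparse input.
    model = Pipeline([
        ('embed', RandomTreesEmbedding(n_estimators=10, sparse_output=False,
                                       random_state=0)),
        ('clf', GaussianNB()),
    ])
    model.fit(X, y)
    print(model.score(X, y))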

sklearn/ensemble/tests/test_forest.py

Lines changed: 35 additions & 0 deletions
@@ -377,6 +377,41 @@ def test_classes_shape():
         yield check_classes_shape, name
 
 
+def test_random_trees_dense_type():
+    '''
+    Test that the `sparse_output` parameter of RandomTreesEmbedding
+    works by returning a dense array.
+    '''
+
+    # Create the RTE with sparse=False
+    hasher = RandomTreesEmbedding(n_estimators=10, sparse_output=False)
+    X, y = datasets.make_circles(factor=0.5)
+    X_transformed = hasher.fit_transform(X)
+
+    # Assert that type is ndarray, not scipy.sparse.csr.csr_matrix
+    assert_equal(type(X_transformed), np.ndarray)
+
+
+def test_random_trees_dense_equal():
+    '''
+    Test that the `sparse_output` parameter of RandomTreesEmbedding
+    works by returning the same array for both argument
+    values.
+    '''
+
+    # Create the RTEs
+    hasher_dense = RandomTreesEmbedding(n_estimators=10, sparse_output=False,
+                                        random_state=0)
+    hasher_sparse = RandomTreesEmbedding(n_estimators=10, sparse_output=True,
+                                         random_state=0)
+    X, y = datasets.make_circles(factor=0.5)
+    X_transformed_dense = hasher_dense.fit_transform(X)
+    X_transformed_sparse = hasher_sparse.fit_transform(X)
+
+    # Assert that dense and sparse hashers have same array.
+    assert_array_equal(X_transformed_sparse.toarray(), X_transformed_dense)
+
+
 def test_random_hasher():
     # test random forest hashing on circles dataset
     # make sure that it is linearly separable.
