Skip to content

Commit 593f4bd

Browse files
committed
ENH: RandomForestEmbedding in lle_digits example
1 parent 6b096cf commit 593f4bd

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

examples/manifold/plot_lle_digits.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,18 @@
44
=============================================================================
55
66
An illustration of various embeddings on the digits dataset.
7+
8+
The RandomForestEmbedding, from the :mod:`sklearn.ensemble` module, is not
9+
technically a manifold embedding method, as it learns a high-dimensional
10+
representation on which we apply a dimensionality reduction method.
11+
However, it is often useful to cast a dataset into a representation in
12+
which the classes are linearly separable.
713
"""
814

915
# Authors: Fabian Pedregosa <[email protected]>
1016
# Olivier Grisel <[email protected]>
1117
# Mathieu Blondel <[email protected]>
18+
# Gael Varoquaux
1219
# License: BSD, (C) INRIA 2011
1320

1421
print __doc__
@@ -18,7 +25,7 @@
1825
import pylab as pl
1926
from matplotlib import offsetbox
2027
from sklearn.utils.fixes import qr_economic
21-
from sklearn import manifold, datasets, decomposition, lda
28+
from sklearn import manifold, datasets, decomposition, ensemble, lda
2229
from sklearn.metrics import euclidean_distances
2330

2431
digits = datasets.load_digits(n_class=6)
@@ -179,4 +186,18 @@ def plot_embedding(X, title=None):
179186
"MDS embedding of the digits (time %.2fs)" %
180187
(time() - t0))
181188

189+
#----------------------------------------------------------------------
190+
# Random Forest embedding of the digits dataset
191+
print "Computing Random Forest embedding"
192+
hasher = ensemble.RandomForestEmbedding(n_estimators=200, random_state=0,
193+
max_depth=5)
194+
t0 = time()
195+
X_transformed = hasher.fit_transform(X)
196+
pca = decomposition.RandomizedPCA(n_components=2)
197+
X_reduced = pca.fit_transform(X_transformed)
198+
199+
plot_embedding(X_reduced,
200+
"Random forest embedding of the digits (time %.2fs)" %
201+
(time() - t0))
202+
182203
pl.show()

0 commit comments

Comments
 (0)