|
4 | 4 | ============================================================================= |
5 | 5 |
|
6 | 6 | An illustration of various embeddings on the digits dataset. |
| 7 | +
|
| 8 | +The RandomForestEmbedding, from the :mod:`sklearn.ensemble` module, is not |
| 9 | +technically a manifold embedding method, as it learns a high-dimensional |
| 10 | +representation on which we apply a dimensionality reduction method. |
| 11 | +However, it is often useful to cast a dataset into a representation in |
| 12 | +which the classes are linearly separable. |
7 | 13 | """ |
8 | 14 |
|
9 | 15 | # Authors: Fabian Pedregosa <[email protected]> |
10 | 16 | # Olivier Grisel <[email protected]> |
11 | 17 | # Mathieu Blondel <[email protected]> |
| 18 | +# Gael Varoquaux |
12 | 19 | # License: BSD, (C) INRIA 2011 |
13 | 20 |
|
14 | 21 | print __doc__ |
|
18 | 25 | import pylab as pl |
19 | 26 | from matplotlib import offsetbox |
20 | 27 | from sklearn.utils.fixes import qr_economic |
21 | | -from sklearn import manifold, datasets, decomposition, lda |
| 28 | +from sklearn import manifold, datasets, decomposition, ensemble, lda |
22 | 29 | from sklearn.metrics import euclidean_distances |
23 | 30 |
|
24 | 31 | digits = datasets.load_digits(n_class=6) |
@@ -179,4 +186,18 @@ def plot_embedding(X, title=None): |
179 | 186 | "MDS embedding of the digits (time %.2fs)" % |
180 | 187 | (time() - t0)) |
181 | 188 |
|
| 189 | +#---------------------------------------------------------------------- |
| 190 | +# Random Forest embedding of the digits dataset |
| 191 | +print "Computing Random Forest embedding" |
| 192 | +hasher = ensemble.RandomForestEmbedding(n_estimators=200, random_state=0, |
| 193 | + max_depth=5) |
| 194 | +t0 = time() |
| 195 | +X_transformed = hasher.fit_transform(X) |
| 196 | +pca = decomposition.RandomizedPCA(n_components=2) |
| 197 | +X_reduced = pca.fit_transform(X_transformed) |
| 198 | + |
| 199 | +plot_embedding(X_reduced, |
| 200 | + "Random forest embedding of the digits (time %.2fs)" % |
| 201 | + (time() - t0)) |
| 202 | + |
182 | 203 | pl.show() |
0 commit comments