
Commit 6e3472c

ENH added docs, example and tests.
1 parent 973edd5 commit 6e3472c

5 files changed: +124, -6 lines

doc/modules/classes.rst

Lines changed: 1 addition & 0 deletions
@@ -250,6 +250,7 @@ Samples generator
    :template: class.rst

    ensemble.RandomForestClassifier
+   ensemble.RandomForestHasher
    ensemble.RandomForestRegressor
    ensemble.ExtraTreesClassifier
    ensemble.ExtraTreesRegressor

sklearn/ensemble/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 from .base import BaseEnsemble
 from .forest import RandomForestClassifier
 from .forest import RandomForestRegressor
+from .forest import RandomForestHasher
 from .forest import ExtraTreesClassifier
 from .forest import ExtraTreesRegressor
 from .gradient_boosting import GradientBoostingClassifier
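With this re-export in place, the new estimator is importable directly from the ensemble package. A quick sanity check (assuming an installed build that contains this commit; the defaults come from the forest.py diff below):

    from sklearn.ensemble import RandomForestHasher

    hasher = RandomForestHasher()
    print(hasher.n_estimators)   # 10 trees by default
    print(hasher.max_depth)      # each grown to depth at most 5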

sklearn/ensemble/forest.py

Lines changed: 92 additions & 3 deletions
@@ -1159,7 +1159,65 @@ def __init__(self, n_estimators=10,
         self.max_features = max_features


-class RandomHashingForest(ExtraTreesClassifier):
+class RandomForestHasher(ExtraTreesClassifier):
+    """Use a completely random forest to create sparse, binary representations.
+
+    An unsupervised transformation of a dataset to a high-dimensional
+    sparse representation. A datapoint is coded according to which leaf of
+    each tree it is sorted into. Using a one-hot encoding of the leaves,
+    this leads to a binary coding with as many ones as there are trees in
+    the forest.
+
+    The dimensionality of the resulting representation is at most
+    ``n_estimators * 2 ** max_depth``.
+
+    Parameters
+    ----------
+    n_estimators : int
+        Number of trees in the forest.
+
+    max_depth : int
+        Maximum depth of each tree.
+
+    min_samples_split : integer, optional (default=1)
+        The minimum number of samples required to split an internal node.
+        Note: this parameter is tree-specific.
+
+    min_samples_leaf : integer, optional (default=1)
+        The minimum number of samples in newly created leaves. A split is
+        discarded if, after the split, one of the leaves would contain fewer
+        than ``min_samples_leaf`` samples.
+        Note: this parameter is tree-specific.
+
+    min_density : float, optional (default=0.1)
+        This parameter controls a trade-off in an optimization heuristic. It
+        controls the minimum density of the `sample_mask` (i.e. the
+        fraction of samples in the mask). If the density falls below this
+        threshold the mask is recomputed and the input data is packed,
+        which results in data copying. If `min_density` equals one, the
+        partitions are always represented as copies of the original
+        data. Otherwise, partitions are represented as bit masks (aka
+        sample masks).
+
+    n_jobs : integer, optional (default=1)
+        The number of jobs to run in parallel. If -1, then the number of
+        jobs is set to the number of cores.
+
+    random_state : int, RandomState instance or None, optional (default=None)
+        If int, random_state is the seed used by the random number generator;
+        If RandomState instance, random_state is the random number generator;
+        If None, the random number generator is the RandomState instance
+        used by `np.random`.
+
+    verbose : int, optional (default=0)
+        Controls the verbosity of the tree building process.
+
+    Attributes
+    ----------
+    `estimators_` : list of DecisionTreeClassifier
+        The collection of fitted sub-estimators.
+    """
+
     def __init__(self, n_estimators=10,
                  max_depth=5,
                  min_samples_split=1,
@@ -1168,7 +1226,7 @@ def __init__(self, n_estimators=10,
                  n_jobs=1,
                  random_state=None,
                  verbose=0):
-        super(RandomHashingForest, self).__init__(
+        super(RandomForestHasher, self).__init__(
             n_estimators=n_estimators,
             max_depth=max_depth,
             min_samples_split=min_samples_split,
@@ -1183,14 +1241,45 @@ def __init__(self, n_estimators=10,
             verbose=verbose)

     def fit(self, X, y=None):
+        """Fit estimator.
+
+        Parameters
+        ----------
+        X : array-like, shape=(n_samples, n_features)
+            Input data used to build forests.
+        """
         self.fit_transform(X, y)
         return self

     def fit_transform(self, X, y=None):
+        """Fit estimator and transform dataset.
+
+        Parameters
+        ----------
+        X : array-like, shape=(n_samples, n_features)
+            Input data used to build forests.
+
+        Returns
+        -------
+        X_transformed : sparse matrix, shape=(n_samples, n_out)
+            Transformed dataset.
+        """
         y = np.arange(len(X))
-        super(RandomHashingForest, self).fit(X, y)
+        super(RandomForestHasher, self).fit(X, y)
         self.one_hot_encoder_ = OneHotEncoder()
         return self.one_hot_encoder_.fit_transform(self.apply(X))

     def transform(self, X):
+        """Transform dataset.
+
+        Parameters
+        ----------
+        X : array-like, shape=(n_samples, n_features)
+            Input data to be transformed.
+
+        Returns
+        -------
+        X_transformed : sparse matrix, shape=(n_samples, n_out)
+            Transformed dataset.
+        """
         return self.one_hot_encoder_.transform(self.apply(X))
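A minimal usage sketch of the class added above (assuming a scikit-learn build containing this commit). Note the trick in `fit_transform`: `np.arange(len(X))` gives every sample its own dummy label, so the completely random trees partition the data without any real targets; `apply` then records the leaf each sample reaches in each tree, and `OneHotEncoder` turns those leaf indices into the sparse binary code:

    import numpy as np
    from sklearn import datasets
    from sklearn.ensemble import RandomForestHasher

    X, y = datasets.make_circles(factor=0.5)   # 100 samples, 2 features

    # 10 completely random trees of depth at most 5: the one-hot coded
    # leaves give at most 10 * 2 ** 5 = 320 binary output features.
    hasher = RandomForestHasher(n_estimators=10, max_depth=5, random_state=0)
    X_transformed = hasher.fit_transform(X)

    print(X_transformed.shape)         # (100, n_out) with n_out <= 320
    # exactly one leaf per tree is active for every sample:
    print(X_transformed.sum(axis=1))   # every row sums to n_estimators = 10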

sklearn/ensemble/tests/test_forest.py

Lines changed: 27 additions & 1 deletion
@@ -2,7 +2,7 @@
 Testing for the forest module (sklearn.ensemble.forest).
 """

-# Authors: Gilles Louppe, Brian Holt
+# Authors: Gilles Louppe, Brian Holt, Andreas Mueller
 # License: BSD 3

 import numpy as np
@@ -17,8 +17,11 @@
 from sklearn.grid_search import GridSearchCV
 from sklearn.ensemble import RandomForestClassifier
 from sklearn.ensemble import RandomForestRegressor
+from sklearn.ensemble import RandomForestHasher
 from sklearn.ensemble import ExtraTreesClassifier
 from sklearn.ensemble import ExtraTreesRegressor
+from sklearn.svm import LinearSVC
+from sklearn.decomposition import RandomizedPCA
 from sklearn import datasets

 # toy sample
@@ -372,6 +375,29 @@ def test_multioutput():
     np.seterr(**olderr)


+def test_random_hasher():
+    # Test random forest hashing on the circles dataset: the transformed
+    # data should be linearly separable, even after being projected down
+    # to two PCA dimensions.
+    hasher = RandomForestHasher(n_estimators=30, random_state=0)
+    X, y = datasets.make_circles(factor=0.5)
+    X_transformed = hasher.fit_transform(X)
+
+    # fit followed by transform must match fit_transform:
+    hasher = RandomForestHasher(n_estimators=30, random_state=0)
+    assert_array_equal(hasher.fit(X).transform(X).toarray(),
+                       X_transformed.toarray())
+
+    # exactly one leaf per tree is active for each data point
+    assert_equal(X_transformed.shape[0], X.shape[0])
+    assert_array_equal(X_transformed.sum(axis=1), hasher.n_estimators)
+
+    pca = RandomizedPCA(n_components=2)
+    X_reduced = pca.fit_transform(X_transformed)
+    linear_clf = LinearSVC()
+    linear_clf.fit(X_reduced, y)
+    assert_equal(linear_clf.score(X_reduced, y), 1.)
+
+
 if __name__ == "__main__":
     import nose
     nose.runmodule()
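For intuition on what the new test asserts: the row sums equal `n_estimators` because exactly one leaf per tree fires for each sample, and `transform` is simply the one-hot encoding of the leaf indices returned by `apply`. A short sketch of that equivalence, using the `one_hot_encoder_` attribute that `fit_transform` stores per the forest.py diff above:

    import numpy as np
    from sklearn import datasets
    from sklearn.ensemble import RandomForestHasher

    X, _ = datasets.make_circles(factor=0.5)
    hasher = RandomForestHasher(n_estimators=30, random_state=0)
    X_transformed = hasher.fit_transform(X)

    # transform(X) == one-hot encoding of the per-tree leaf indices
    leaves = hasher.apply(X)            # shape (n_samples, n_estimators)
    manual = hasher.one_hot_encoder_.transform(leaves)
    np.testing.assert_array_equal(manual.toarray(), X_transformed.toarray())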

sklearn/tests/test_common.py

Lines changed: 3 additions & 2 deletions
@@ -37,7 +37,7 @@
 from sklearn.decomposition import SparseCoder
 from sklearn.pipeline import Pipeline, FeatureUnion
 from sklearn.pls import _PLS, PLSCanonical, PLSRegression, CCA, PLSSVD
-from sklearn.ensemble import BaseEnsemble
+from sklearn.ensemble import BaseEnsemble, RandomForestHasher
 from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier,\
     OutputCodeClassifier
 from sklearn.feature_selection import RFE, RFECV, SelectKBest
@@ -54,7 +54,8 @@

 dont_test = [Pipeline, FeatureUnion, GridSearchCV, SparseCoder,
     EllipticEnvelope, EllipticEnvelop, DictVectorizer, LabelBinarizer,
-    LabelEncoder, TfidfTransformer, IsotonicRegression, OneHotEncoder]
+    LabelEncoder, TfidfTransformer, IsotonicRegression, OneHotEncoder,
+    RandomForestHasher]
 meta_estimators = [BaseEnsemble, OneVsOneClassifier, OutputCodeClassifier,
     OneVsRestClassifier, RFE, RFECV]
