Skip to content

Commit 973edd5

Browse files
committed
ENH add RandomHashingForest estimator.
1 parent 541f01d commit 973edd5

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

sklearn/ensemble/forest.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
4848
from ..tree._tree import DTYPE, DOUBLE
4949
from ..utils import array2d, check_random_state, check_arrays
5050
from ..metrics import r2_score
51+
from ..preprocessing import OneHotEncoder
5152

5253
from .base import BaseEnsemble
5354

@@ -1156,3 +1157,40 @@ def __init__(self, n_estimators=10,
11561157
self.min_samples_leaf = min_samples_leaf
11571158
self.min_density = min_density
11581159
self.max_features = max_features
1160+
1161+
1162+
class RandomHashingForest(ExtraTreesClassifier):
    """Unsupervised forest that one-hot encodes the output of ``apply``.

    The forest is an ``ExtraTreesClassifier`` configured with
    ``max_features=1``, ``bootstrap=False``, ``compute_importances=False``
    and ``oob_score=False``.  No real target is used: every training
    sample is given its own integer label (``np.arange(len(X))``) before
    the trees are grown.  The values returned by ``apply`` (presumably the
    per-tree leaf indices -- confirm against ``apply``) are then passed
    through a ``OneHotEncoder``.

    Parameters
    ----------
    n_estimators, max_depth, min_samples_split, min_samples_leaf, \
min_density, n_jobs, random_state, verbose
        Forwarded unchanged to ``ExtraTreesClassifier.__init__``.

    Attributes
    ----------
    one_hot_encoder_ : OneHotEncoder
        Encoder fitted on ``self.apply(X)`` during ``fit_transform``.
    """

    def __init__(self, n_estimators=10,
                 max_depth=5,
                 min_samples_split=1,
                 min_samples_leaf=1,
                 min_density=0.1,
                 n_jobs=1,
                 random_state=None,
                 verbose=0):
        # The settings below are fixed by this estimator and are not
        # exposed to the caller.
        super(RandomHashingForest, self).__init__(
            n_estimators=n_estimators,
            max_depth=max_depth,
            min_samples_split=min_samples_split,
            min_samples_leaf=min_samples_leaf,
            min_density=min_density,
            max_features=1,
            bootstrap=False,
            compute_importances=False,
            oob_score=False,
            n_jobs=n_jobs,
            random_state=random_state,
            verbose=verbose)

    def fit(self, X, y=None):
        """Fit the forest and its encoder on ``X``; return ``self``.

        ``y`` is accepted for API compatibility; ``fit_transform``
        replaces it with a surrogate target.
        """
        # The encoded output is discarded; only the fitted state matters.
        self.fit_transform(X, y)
        return self

    def fit_transform(self, X, y=None):
        """Fit on ``X`` and return the one-hot encoding of ``apply(X)``.

        The ``y`` argument is ignored: it is overwritten with
        ``np.arange(len(X))`` so that each sample has a distinct label.
        """
        # Surrogate target: one unique label per training sample.
        sample_labels = np.arange(len(X))
        super(RandomHashingForest, self).fit(X, sample_labels)

        encoder = OneHotEncoder()
        self.one_hot_encoder_ = encoder
        applied = self.apply(X)
        return encoder.fit_transform(applied)

    def transform(self, X):
        """Return the one-hot encoding of ``apply(X)``.

        Uses the encoder stored in ``one_hot_encoder_`` by
        ``fit_transform``.
        """
        applied = self.apply(X)
        return self.one_hot_encoder_.transform(applied)

0 commit comments

Comments
 (0)