Skip to content

Commit 195de6a

Browse files
hrjnTomDLT
authored andcommitted
[MRG+1] Fix multi-label issues in IsolationForest benchmark (scikit-learn#8638)
* Fixed a legacy multi-label issue and added minor refactoring and changes (mostly esthethic) Minor corrections after code review. Minor corrections after 2nd code review. Minor modif * rerun CI
1 parent 9e0e2d4 commit 195de6a

File tree

1 file changed

+62
-42
lines changed

1 file changed

+62
-42
lines changed

benchmarks/bench_isolation_forest.py

Lines changed: 62 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,51 @@
22
==========================================
33
IsolationForest benchmark
44
==========================================
5-
65
A test of IsolationForest on classical anomaly detection datasets.
7-
86
"""
9-
print(__doc__)
107

118
from time import time
129
import numpy as np
1310
import matplotlib.pyplot as plt
11+
1412
from sklearn.ensemble import IsolationForest
1513
from sklearn.metrics import roc_curve, auc
1614
from sklearn.datasets import fetch_kddcup99, fetch_covtype, fetch_mldata
17-
from sklearn.preprocessing import LabelBinarizer
15+
from sklearn.preprocessing import MultiLabelBinarizer
1816
from sklearn.utils import shuffle as sh
1917

20-
np.random.seed(1)
18+
print(__doc__)
2119

22-
datasets = ['http', 'smtp', 'SA', 'SF', 'shuttle', 'forestcover']
2320

21+
def print_outlier_ratio(y):
22+
"""
23+
Helper function to show the distinct value count of element in the target.
24+
Useful indicator for the datasets used in bench_isolation_forest.py.
25+
"""
26+
uniq, cnt = np.unique(y, return_counts=True)
27+
print("----- Target count values: ")
28+
for u, c in zip(uniq, cnt):
29+
print("------ %s -> %d occurences" % (str(u), c))
30+
print("----- Outlier ratio: %.5f" % (np.min(cnt) / len(y)))
31+
32+
33+
np.random.seed(1)
2434
fig_roc, ax_roc = plt.subplots(1, 1, figsize=(8, 5))
2535

36+
# Set this to true for plotting score histograms for each dataset:
37+
with_decision_function_histograms = False
2638

39+
# Removed the shuttle dataset because as of 2017-03-23 mldata.org is down:
40+
# datasets = ['http', 'smtp', 'SA', 'SF', 'shuttle', 'forestcover']
41+
datasets = ['http', 'smtp', 'SA', 'SF', 'forestcover']
42+
43+
# Loop over all datasets for fitting and scoring the estimator:
2744
for dat in datasets:
28-
# loading and vectorization
29-
print('loading data')
30-
if dat in ['http', 'smtp', 'SA', 'SF']:
45+
46+
# Loading and vectorizing the data:
47+
print('====== %s ======' % dat)
48+
print('--- Fetching data...')
49+
if dat in ['http', 'smtp', 'SF', 'SA']:
3150
dataset = fetch_kddcup99(subset=dat, shuffle=True, percent10=True)
3251
X = dataset.data
3352
y = dataset.target
@@ -43,6 +62,7 @@
4362
X = X[s, :]
4463
y = y[s]
4564
y = (y != 1).astype(int)
65+
print('----- ')
4666

4767
if dat == 'forestcover':
4868
dataset = fetch_covtype(shuffle=True)
@@ -54,29 +74,29 @@
5474
X = X[s, :]
5575
y = y[s]
5676
y = (y != 2).astype(int)
77+
print_outlier_ratio(y)
5778

58-
print('vectorizing data')
79+
print('--- Vectorizing data...')
5980

6081
if dat == 'SF':
61-
lb = LabelBinarizer()
62-
lb.fit(X[:, 1])
63-
x1 = lb.transform(X[:, 1])
82+
lb = MultiLabelBinarizer()
83+
x1 = lb.fit_transform(X[:, 1])
6484
X = np.c_[X[:, :1], x1, X[:, 2:]]
65-
y = (y != 'normal.').astype(int)
85+
y = (y != b'normal.').astype(int)
86+
print_outlier_ratio(y)
6687

6788
if dat == 'SA':
68-
lb = LabelBinarizer()
69-
lb.fit(X[:, 1])
70-
x1 = lb.transform(X[:, 1])
71-
lb.fit(X[:, 2])
72-
x2 = lb.transform(X[:, 2])
73-
lb.fit(X[:, 3])
74-
x3 = lb.transform(X[:, 3])
89+
lb = MultiLabelBinarizer()
90+
x1 = lb.fit_transform(X[:, 1])
91+
x2 = lb.fit_transform(X[:, 2])
92+
x3 = lb.fit_transform(X[:, 3])
7593
X = np.c_[X[:, :1], x1, x2, x3, X[:, 4:]]
76-
y = (y != 'normal.').astype(int)
94+
y = (y != b'normal.').astype(int)
95+
print_outlier_ratio(y)
7796

78-
if dat == 'http' or dat == 'smtp':
79-
y = (y != 'normal.').astype(int)
97+
if dat in ('http', 'smtp'):
98+
y = (y != b'normal.').astype(int)
99+
print_outlier_ratio(y)
80100

81101
n_samples, n_features = X.shape
82102
n_samples_train = n_samples // 2
@@ -87,34 +107,34 @@
87107
y_train = y[:n_samples_train]
88108
y_test = y[n_samples_train:]
89109

90-
print('IsolationForest processing...')
110+
print('--- Fitting the IsolationForest estimator...')
91111
model = IsolationForest(n_jobs=-1)
92112
tstart = time()
93113
model.fit(X_train)
94114
fit_time = time() - tstart
95115
tstart = time()
96116

97-
scoring = - model.decision_function(X_test) # the lower, the more normal
98-
99-
# Show score histograms
100-
fig, ax = plt.subplots(3, sharex=True, sharey=True)
101-
bins = np.linspace(-0.5, 0.5, 200)
102-
ax[0].hist(scoring, bins, color='black')
103-
ax[0].set_title('decision function for %s dataset' % dat)
104-
ax[0].legend(loc="lower right")
105-
ax[1].hist(scoring[y_test == 0], bins, color='b',
106-
label='normal data')
107-
ax[1].legend(loc="lower right")
108-
ax[2].hist(scoring[y_test == 1], bins, color='r',
109-
label='outliers')
110-
ax[2].legend(loc="lower right")
117+
scoring = - model.decision_function(X_test) # the lower, the more abnormal
118+
119+
print("--- Preparing the plot elements...")
120+
if with_decision_function_histograms:
121+
fig, ax = plt.subplots(3, sharex=True, sharey=True)
122+
bins = np.linspace(-0.5, 0.5, 200)
123+
ax[0].hist(scoring, bins, color='black')
124+
ax[0].set_title('Decision function for %s dataset' % dat)
125+
ax[1].hist(scoring[y_test == 0], bins, color='b', label='normal data')
126+
ax[1].legend(loc="lower right")
127+
ax[2].hist(scoring[y_test == 1], bins, color='r', label='outliers')
128+
ax[2].legend(loc="lower right")
111129

112130
# Show ROC Curves
113131
predict_time = time() - tstart
114132
fpr, tpr, thresholds = roc_curve(y_test, scoring)
115-
AUC = auc(fpr, tpr)
116-
label = ('%s (area: %0.3f, train-time: %0.2fs, '
117-
'test-time: %0.2fs)' % (dat, AUC, fit_time, predict_time))
133+
auc_score = auc(fpr, tpr)
134+
label = ('%s (AUC: %0.3f, train_time= %0.2fs, '
135+
'test_time= %0.2fs)' % (dat, auc_score, fit_time, predict_time))
136+
# Print AUC score and train/test time:
137+
print(label)
118138
ax_roc.plot(fpr, tpr, lw=1, label=label)
119139

120140

0 commit comments

Comments
 (0)