"""
==========================================
IsolationForest benchmark
==========================================

A benchmark of IsolationForest on classical anomaly detection datasets
(KDD Cup 99 subsets, Shuttle, Forest Cover type): for each dataset the
ROC curve, its AUC, and the fit / prediction times are reported.

"""
print(__doc__)

from time import time

import numpy as np
import matplotlib.pyplot as plt

from sklearn.ensemble import IsolationForest
from sklearn.metrics import roc_curve, auc
from sklearn.datasets import fetch_kddcup99, fetch_covtype, fetch_mldata
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils import shuffle as sh

np.random.seed(1)

datasets = ['http', 'smtp', 'SA', 'SF', 'shuttle', 'forestcover']

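# For each dataset below: load it, binarize the labels so that 1 marks an
# anomaly, one-hot encode the remaining categorical columns, fit an
# IsolationForest on the first half of the (shuffled) data and report the
# ROC curve obtained on the second half.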
for dat in datasets:
    # loading and vectorization
    print('loading data')
    if dat in ['http', 'smtp', 'SA', 'SF']:
        dataset = fetch_kddcup99(subset=dat, shuffle=True, percent10=True)
        X = dataset.data
        y = dataset.target

    if dat == 'shuttle':
        dataset = fetch_mldata('shuttle')
        X = dataset.data
        y = dataset.target
        # sklearn.utils.shuffle returns shuffled copies (it does not work
        # in place), so the result has to be assigned back
        X, y = sh(X, y)
        # we remove data with label 4
        # normal data are then those of class 1
        s = (y != 4)
        X = X[s, :]
        y = y[s]
        y = (y != 1).astype(int)

    if dat == 'forestcover':
        dataset = fetch_covtype(shuffle=True)
        X = dataset.data
        y = dataset.target
        # normal data are those with attribute 2
        # abnormal those with attribute 4
        s = (y == 2) | (y == 4)
        X = X[s, :]
        y = y[s]
        y = (y != 2).astype(int)

    print('vectorizing data')

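    # The SF and SA subsets still contain symbolic attributes (such as the
    # protocol type, service and flag columns); LabelBinarizer one-hot
    # encodes each of them so that X can be cast to float below.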
    if dat == 'SF':
        lb = LabelBinarizer()
        x1 = lb.fit_transform(X[:, 1])
        X = np.c_[X[:, :1], x1, X[:, 2:]]
        y = (y != 'normal.').astype(int)

    if dat == 'SA':
        lb = LabelBinarizer()
        x1 = lb.fit_transform(X[:, 1])
        x2 = lb.fit_transform(X[:, 2])
        x3 = lb.fit_transform(X[:, 3])
        X = np.c_[X[:, :1], x1, x2, x3, X[:, 4:]]
        y = (y != 'normal.').astype(int)

    if dat in ('http', 'smtp'):
        y = (y != 'normal.').astype(int)

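    # The data has been shuffled above, so a plain 50/50 split yields an
    # unsupervised training half and a held-out evaluation half.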
    n_samples, n_features = np.shape(X)
    n_samples_train = n_samples // 2
    n_samples_test = n_samples - n_samples_train

    X = X.astype(float)
    X_train = X[:n_samples_train, :]
    X_test = X[n_samples_train:, :]
    y_train = y[:n_samples_train]
    y_test = y[n_samples_train:]

    print('IsolationForest processing...')
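    # Default forest size (100 trees, each grown on a subsample of at most
    # 256 points); bootstrap=True draws these subsamples with replacement
    # and n_jobs=-1 uses all available cores.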
    model = IsolationForest(bootstrap=True, n_jobs=-1)
    tstart = time()
    model.fit(X_train)
    fit_time = time() - tstart

    tstart = time()
    # anomaly score of each test sample: the higher, the more abnormal
    scoring = -model.decision_function(X_test)
    predict_time = time() - tstart
    fpr, tpr, thresholds = roc_curve(y_test, scoring)
    AUC = auc(fpr, tpr)
    label = ('ROC for %s (area = %0.3f, train-time: %0.2fs, '
             'test-time: %0.2fs)' % (dat, AUC, fit_time, predict_time))
    plt.plot(fpr, tpr, lw=1, label=label)

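# All ROC curves are drawn on the same axes so the datasets can be compared
# at a glance.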
plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")
plt.show()