Skip to content

Commit 64b01c8

Browse files
committed
Merge pull request #1 from anujonthemove/anujonthemove-patch-1
updated file parser, n_folds=5, n_jobs=-1
2 parents e2acc0b + 4218ef5 commit 64b01c8

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

crossValidate.py

Lines changed: 11 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -2,28 +2,32 @@
22
from sklearn import cross_validation
33
import logloss
44
import numpy as np
5+
import pandas as pd
6+
from sklearn.preprocessing import Imputer
57

68
def main():
# Entry point: loads Data/train.csv, cross-validates a random forest with
# K-fold splits, and prints the mean log loss over the folds.
# NOTE(review): this is a diff rendering -- a gutter value ending in "-" marks
# a line removed by the commit, "+" marks its replacement, and a plain number
# marks an unchanged context line.
79
# Read in the data and split it into the target column and the training features.
8-
dataset = np.genfromtxt(open('Data/train.csv','r'), delimiter=',', dtype='f8')[1:]
9-
target = np.array([x[0] for x in dataset])
10-
train = np.array([x[1:] for x in dataset])
10+
dataset = pd.read_csv('Data/train.csv')
11+
# The 'Activity' column is the target; every other column is a feature.
target = dataset.Activity.values
12+
train = dataset.drop('Activity', axis=1).values
13+
# Fill entries marked 'NaN' using the mean strategy along axis 0 before fitting.
imp = Imputer(missing_values = 'NaN',strategy='mean',axis=0)
14+
new_train_data = imp.fit_transform(train)
1115

1216
# A random forest is used here, but any classifier exposing fit/predict_proba
# could be substituted.
13-
cfr = RandomForestClassifier(n_estimators=100)
17+
# NOTE(review): n_jobs=-1 presumably asks scikit-learn to use all available
# cores -- confirm against the installed version.
cfr = RandomForestClassifier(n_estimators=100, n_jobs=-1)
1418

1519
# Simple K-fold cross validation with 5 folds.
16-
cv = cross_validation.KFold(len(train), k=5, indices=False)
20+
# The commit renames the fold-count keyword from k to n_folds and sizes the
# splitter from the imputed array.
# NOTE(review): indices=False presumably yields boolean masks rather than
# integer indices -- confirm against the installed scikit-learn version.
cv = cross_validation.KFold(len(new_train_data), n_folds=5, indices=False)
1721

1822
# Iterate through the train/test segments of each fold, fit the classifier on
1923
# the training part, and collect the log loss on the held-out part in a list.
2024
results = []
2125
for traincv, testcv in cv:
22-
probas = cfr.fit(train[traincv], target[traincv]).predict_proba(train[testcv])
26+
probas = cfr.fit(new_train_data[traincv], target[traincv]).predict_proba(new_train_data[testcv])
2327
# x[1] takes element 1 of each predict_proba row -- presumably the probability
# of the positive class; confirm against how Activity is encoded.
results.append( logloss.llfun(target[testcv], [x[1] for x in probas]) )
2428

2529
# Print the mean of the per-fold log-loss values.
2630
print "Results: " + str( np.array(results).mean() )
2731

2832
if __name__=="__main__":
# Run the cross-validation only when this file is executed as a script,
# not when it is imported as a module.
29-
main()
33+
main()

0 commit comments

Comments
 (0)