1+ import sklearn .datasets as datasets
2+ import numpy as np
3+ from sklearn .linear_model import LogisticRegression as LR
4+ class LogisticRegression ():
5+ def __init__ (self ,alpha = 0.01 ,epochs = 3 ):
6+ self .W = None
7+ self .b = None
8+ self .alpha = alpha
9+ self .epochs = epochs
10+ def fit (self ,X ,y ):
11+ # 设定种子
12+ np .random .seed (10 )
13+ self .W = np .random .normal (size = (X .shape [1 ]))
14+ self .b = 0
15+ for epoch in range (self .epochs ):
16+ if epoch % 50 == 0 :
17+ print ("epoch" ,epoch )
18+ w_derivate = np .zeros_like (self .W )
19+ b_derivate = 0
20+ for i in range (len (y )):
21+ w_derivate += (y [i ] - 1 / (1 + np .exp (- np .dot (X [i ],self .W .T )- self .b )))* X [i ]
22+ b_derivate += (y [i ] - 1 / (1 + np .exp (- np .dot (X [i ],self .W .T )- self .b )))
23+ self .W = self .W + self .alpha * np .mean (w_derivate ,axis = 0 )
24+ self .b = self .b + self .alpha * np .mean (b_derivate )
25+ return self
26+ def predict (self ,X ):
27+ p_1 = 1 / (1 + np .exp (- np .dot (X ,self .W ) - self .b ))
28+ return np .where (p_1 > 0.5 , 1 , 0 )
29+ def accuracy (pred , true ):
30+ count = 0
31+ for i in range (len (pred )):
32+ if (pred [i ] == true [i ]):
33+ count += 1
34+ return count / len (pred )
35+ def normalize (x ):
36+ return (x - np .min (x ))/ (np .max (x ) - np .min (x ))
37+
38+ if __name__ == "__main__" :
39+ # input datasets
40+ digits = datasets .load_breast_cancer ()
41+ X = digits .data
42+ y = digits .target
43+ # 归一化
44+ X_norm = normalize (X )
45+ X_train = X_norm [:int (len (X_norm )* 0.8 )]
46+ X_test = X_norm [int (len (X_norm )* 0.8 ):]
47+ y_train = y [:int (len (X_norm )* 0.8 )]
48+ y_test = y [int (len (X_norm )* 0.8 ):]
49+ # model 1
50+ lr = LogisticRegression (epochs = 500 ,alpha = 0.03 )
51+ lr .fit (X_train ,y_train )
52+ y_pred = lr .predict (X_test )
53+ # 评估准确率
54+ acc = accuracy (y_pred , y_test )
55+ print ("acc" , acc )
56+ # model 2
57+ clf_lr = LR ()
58+ clf_lr .fit (X_train , y_train )
59+ y_pred2 = clf_lr .predict (X_test )
60+ print ("acc2" , accuracy (y_pred2 , y_test ))
0 commit comments