1+ import numpy as np
2+ import matplotlib .pyplot as plt
3+ import pandas as pd
4+
5+ def plot_decision_boundaries (X , y , model_class , ** model_params ):
6+ """Function to plot the decision boundaries of a classification model.
7+ This uses just the first two columns of the data for fitting
8+ the model as we need to find the predicted value for every point in
9+ scatter plot.
10+
11+ Arguments:
12+ X: Feature data as a NumPy-type array.
13+ y: Label data as a NumPy-type array.
14+ model_class: A Scikit-learn ML estimator class
15+ e.g. GaussianNB (imported from sklearn.naive_bayes) or
16+ LogisticRegression (imported from sklearn.linear_model)
17+ **model_params: Model parameters to be passed on to the ML estimator
18+
19+ Typical code example:
20+ plt.figure()
21+ plt.title("KNN decision boundary with neighbros: 5",fontsize=16)
22+ plot_decision_boundaries(X_train,y_train,KNeighborsClassifier,n_neighbors=5)
23+ plt.show()
24+ """
25+ try :
26+ X = np .array (X )
27+ y = np .array (y ).flatten ()
28+ except :
29+ print ("Coercing input data to NumPy arrays failed" )
30+ reduced_data = X [:, :2 ]
31+ model = model_class (** model_params )
32+ model .fit (reduced_data , y )
33+
34+ # Step size of the mesh. Decrease to increase the quality of the VQ.
35+ h = .02 # point in the mesh [x_min, m_max]x[y_min, y_max].
36+
37+ # Plot the decision boundary. For that, we will assign a color to each
38+ x_min , x_max = reduced_data [:, 0 ].min () - 1 , reduced_data [:, 0 ].max () + 1
39+ y_min , y_max = reduced_data [:, 1 ].min () - 1 , reduced_data [:, 1 ].max () + 1
40+ xx , yy = np .meshgrid (np .arange (x_min , x_max , h ), np .arange (y_min , y_max , h ))
41+
42+ # Obtain labels for each point in mesh using the model.
43+ Z = model .predict (np .c_ [xx .ravel (), yy .ravel ()])
44+
45+ x_min , x_max = X [:, 0 ].min () - 1 , X [:, 0 ].max () + 1
46+ y_min , y_max = X [:, 1 ].min () - 1 , X [:, 1 ].max () + 1
47+ xx , yy = np .meshgrid (np .arange (x_min , x_max , 0.1 ),
48+ np .arange (y_min , y_max , 0.1 ))
49+
50+ Z = model .predict (np .c_ [xx .ravel (), yy .ravel ()]).reshape (xx .shape )
51+
52+ plt .contourf (xx , yy , Z , alpha = 0.4 )
53+ plt .scatter (X [:, 0 ], X [:, 1 ], c = y , alpha = 0.8 )
54+ plt .xlabel ("Feature-1" ,fontsize = 15 )
55+ plt .ylabel ("Feature-2" ,fontsize = 15 )
56+ plt .xticks (fontsize = 14 )
57+ plt .yticks (fontsize = 14 )
58+ return plt
0 commit comments