33import pandas as pd
44
55def plot_decision_boundaries (X , y , model_class , ** model_params ):
6- """Function to plot the decision boundaries of a classification model.
6+ """
7+ Function to plot the decision boundaries of a classification model.
78 This uses just the first two columns of the data for fitting
89 the model as we need to find the predicted value for every point in
910 scatter plot.
@@ -27,8 +28,11 @@ def plot_decision_boundaries(X, y, model_class, **model_params):
2728 y = np .array (y ).flatten ()
2829 except :
2930 print ("Coercing input data to NumPy arrays failed" )
31+ # Reduces to the first two columns of data
3032 reduced_data = X [:, :2 ]
33+ # Instantiate the model object
3134 model = model_class (** model_params )
35+ # Fits the model with the reduced data
3236 model .fit (reduced_data , y )
3337
3438 # Step size of the mesh. Decrease to increase the quality of the VQ.
@@ -37,6 +41,7 @@ def plot_decision_boundaries(X, y, model_class, **model_params):
3741 # Plot the decision boundary. For that, we will assign a color to each
3842 x_min , x_max = reduced_data [:, 0 ].min () - 1 , reduced_data [:, 0 ].max () + 1
3943 y_min , y_max = reduced_data [:, 1 ].min () - 1 , reduced_data [:, 1 ].max () + 1
44+ # Meshgrid creation
4045 xx , yy = np .meshgrid (np .arange (x_min , x_max , h ), np .arange (y_min , y_max , h ))
4146
4247 # Obtain labels for each point in mesh using the model.
@@ -47,8 +52,10 @@ def plot_decision_boundaries(X, y, model_class, **model_params):
4752 xx , yy = np .meshgrid (np .arange (x_min , x_max , 0.1 ),
4853 np .arange (y_min , y_max , 0.1 ))
4954
55+ # Predictions to obtain the classification results
5056 Z = model .predict (np .c_ [xx .ravel (), yy .ravel ()]).reshape (xx .shape )
5157
58+ # Plotting
5259 plt .contourf (xx , yy , Z , alpha = 0.4 )
5360 plt .scatter (X [:, 0 ], X [:, 1 ], c = y , alpha = 0.8 )
5461 plt .xlabel ("Feature-1" ,fontsize = 15 )
0 commit comments