用于分类的神经网络写法(基于iris数据集)

本文适合神经网络初学者,通过鸢尾花数据集介绍三种神经网络的实现方法:KNN、sklearn的MLPClassifier和Keras、Tensorflow。通过比较训练集和测试集的准确性,展示了每种方法的效果。KNN得分0.958(训练)和1.0(测试),MLPClassifier得分0.975(训练)和1.0(测试),Keras得分0.983(训练)和1.0(测试),Tensorflow得分0.975(训练)和1.0(测试)。

前言

本文本适合于学习神经网络的小白,高手或大神请绕行。
iris(鸢尾花分类)数据集是一个机器学习中的经典数据集,类别分为Setosa,Versicolour,Virginica三种;输入的特征有四个(即四维数据),分别为花萼长度,花萼宽度,花瓣长度,花瓣宽度。
从我们的通识来说,知道同一类的种类会有相近的花形状,不同一类的种类的花会有差别。
因此,可以通过上述四个特征来进行区分。
本博客程序代码iris_data_training.ipynb

数据展示以及准备

环境:python3.7+keras2.3.1+tensorflow1.15.0+sklearn0.22.1+matplotlib3.1.1
三方库:mglearn0.1.7
为什么使用mglearn呢?因为它已经封装好了matplotlib.pyplot的scatter方法,只输入X, y就好了,学习时真的没有必要重复造轮子。

  1. 首先导入通用库
from sklearn.datasets import load_iris
import numpy as np
import matplotlib.pyplot as plt
import mglearn
  1. 加载数据
# get data
iris = load_iris()
X, y = iris.data, iris.target
print('Shape', X.shape, y.shape)	# 输出结果Shape (150, 4) (150,)
RANDOM_STATE = 2	# 为了使用相同的数据好比对,选择相同的随机种子
  1. 由于数据输入有四个维度,而使用简单的展示数据分布要用二维,所以两两地组合展示出来。总共有6种搭配。
# 选择第0维第1维
mglearn.discrete_scatter(X[:, 0], X[:, 1], y)

在这里插入图片描述

# 选择第0维第2维
mglearn.discrete_scatter(X[:, 0], X[:, 2], y)

在这里插入图片描述

# 选择第0维第3维
mglearn.discrete_scatter(X[:, 0], X[:, 3], y)

在这里插入图片描述

# 选择第1维第2维
mglearn.discrete_scatter(X[:, 1], X[:, 2], y)

在这里插入图片描述

# 选择第1维第3维
mglearn.discrete_scatter(X[:, 1], X[:, 3], y)

在这里插入图片描述

# 选择第2维第3维
mglearn.discrete_scatter(X[:, 2], X[:, 3], y)

在这里插入图片描述
所有的组合看完,发现第2维和第3维画出来比较容易看出来分类情况,所以第2维和第3维的二维数据展示仅仅查看看预测的好坏程度。
在这里,我们写一个展示真实结果与预测结果的比对的方法。

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
def eval_on_show(X_train, y_train, y_train_pred, X_test, y_test, y_test_pred):
    f = plt.figure('train set', figsize=(10, 5))
    plt.subplot(121)
    plt.title('train set ture')
    mglearn.discrete_scatter(X_train[:, 2], X_train
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值