图像分类快速入门:原理与代码

本文是图像分类的快速入门指南,介绍了图像分类在安防监控、交通、医疗等领域的应用,细分任务包括二分类、多分类和多标签分类,并通过深度学习中的卷积神经网络(CNN)进行实现,以CIFAR10数据集为例展示训练流程。同时,讨论了图像分类的损失函数,包括二分类、多分类和多标签分类的损失计算。

原创:悬鱼铭

图像分类是人工智能中重要的基础任务,也是目标检测、图像分割、目标跟踪等视觉进阶任务的基础,是人工智能从业者必须掌握的知识点。本文通过以下几点阐述:

  • 图像分类有哪些落地场景?

  • 图像分类有哪些细分任务?

  • 图像分类如何实现?(有代码注释)

  • 图像分类的损失函数

    全文总共3千字左右,阅读时间10分钟!

一、图像分类有哪些落地场景?

图像分类是计算机视觉领域的基础任务,也是应用比较广泛的任务。图像分类用来解决“是什么”的问题,如给定一张图片,用标签描述图片的主要内容,下图中有三个企鹅,标签为企鹅。

来自花瓣网【1】,侵权联系删除

基于图像分类的智能应用也日渐成熟,如在安防监控、智慧交通、医疗影像诊断等。安防监控中,上下班的人脸打卡机器,人们走到机器前,相机采集到人脸图像,机器里的算法判断人脸图像与人脸图像库里的某个人脸最相似,就判断出人脸图像属于那个人。

来自花瓣网【2】,侵权联系删除

交通领域中的交通标识识别,可以辅助驾驶;手机拍照识别花的种类,智能整理相册;电商平台里的输入标签,返回含有标签的商品等。

二、图像分类有哪些细分任务?

图像分类根据标签的不同,大致可分为二分类任务、多分类任务、多标签分类任务。

应对复杂的生活场景,图像分类会有更加细致的任务。在安防监控中,在监控系统中寻找逃犯,从全国各地摄像头采集的图像,面对庞大的图像库,只要找到是逃犯的图像,其他图像都不是逃犯。这里就对图像进行二分类,是逃犯与不是逃犯,这是二分类的任务

在智能整理手机相册时,每张图像会设置一个标签,同一个标签的图像会放在一个文件夹下,并且以标签来命名,整理完之后会有多个文件夹,这里标签有多个,是多分类的任务

在短视频推荐中,真实生活场景,图像包含丰富的内容,每张图像不在局限单标签,而是把图像包含的主要内容展示,往多标签发展。对图像进行多标签分类,提供丰富多样的标签,可以促进个性化推荐。

三、图像分类如何实现?

经典的图像分类一般包括预处理、特征提取、分类器,其中特征提取一般通过手工精心设计。研究者会花费大量的精力去探索如何提取到鲁棒性较好的图像特征。深度学习中的卷积神经网络(Convolution Neural Network, CNN)可在大量数据中自动学习到数据的层次化表示。近年来,得益于强大的计算机、更大的数据集,CNN提取图像特征成为主流方法。

基于深度学习的图像分类,将传统的图像分类流程(预处理、特征提取、分类器),全部体现在各种层的组合,有卷积层、池化层、全连接层,图像分类流程如图1所示。训练过程中主要是求解模型的参数,一个输入图片经过多个卷积、池化,它们提取图像特征图,图像特征图拉伸为一维特征向量,连接全连接层,将特征图映射到标签(类别),可知输入图片属于每个标签的概率值。选取概率值最大的标签作为预测的结果。根据推理的结果与图片的真实标签的差距,即为损失函数,再通过梯度下降的方法求解模型参数。参数确定之后,模型就确定了,可以推理测试集中新的图片。

图像多分类任务

图像分类中常用的数据集有CIFAR10【3】,有6万张图像,其中5万张训练集图像,1万张测试集图像,图像大小为$32 \times 32 $ 的彩色图像。每张图片一个标签,数据集总共10个标签,有鸟(bird)、太阳(sunset)、狗(dog)、猫(cat)等。

下面进入图像多类别分类实践,以VGG16网络为基准模型,在Pytorch中展示图像多类别分类的训练流程,并且将Pytorch中的数据流展示在下图中。

接下来是使用VGG16,进行图像分类,数据集是CIFAR10。

import torch
import torchvision
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

# 预处理的设置
# 图片转化为 backbone网络规定的图片大小
# 归一化是减去均值,除以方差
# 把 numpy array 转化为 tensor 的格式
image_size = 224
r_mean, g_mean, b_mean = 0.4914, 0.4822, 0.4465
r_std, g_std, b_std = 0.247, 0.243, 0.261

my_tf = transforms.Compose([
    transforms.Resize((image_size,image_size)),
    transforms.ToTensor(),
    transforms.Normalize([r_mean,g_mean,b_mean], [r_std,g_std,b_std])])

# 读取数据集 CIFAR-10 的图,有10个标签,5万张图片,进行预处理。
train_dataset= torchvision.datasets.CIFAR10(root='./',train=True,transform=my_tf,download=True)
test_dataset= torchvision.datasets.CIFAR10(root='./',train=False,transform=my_tf,download=True)

# 调用预训练模型vgg16
my_vgg = torchvision.models.vgg16(pretrained=True)
# 固定网络框架全连接层之前的参数
for param in my_vgg.parameters():
    param.requires_grad=False
# 将vgg最后一层输出的类别数,改为cifar-10的类别数(10)
class_size = 10
in_f = my_vgg.classifier[6].in_features
my_vgg.classifier[6] = nn.Linear(in_f,class_size)

# 超参数设置
learn_rate = 0.001
num_epoches = 10
batch_size = 32
momentum = 0.9
# 多分类损失函数,使用默认值
criterion = nn.CrossEntropyLoss()  
# 梯度下降,求解模型最后一层参数
optimizer = optim.SGD(my_vgg.classifier[6].parameters(),lr=
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值