实战加载本地MNIST数据集(GZ格式)

该文介绍了如何下载MNIST数据集,并使用Python的gzip和numpy库加载和解析数据。通过加载函数,分别读取训练和测试标签及图像文件,将二进制数据转换为NumPy数组,调整形状以便后续处理。最后,展示了读取数据后的数组形状。

开发板推荐:天空星STM32F407VET6开发板

超高性价比 STM32主控 | 超高主频 | 一板兼容百芯 | 比赛神器 | 沉金彩色丝印

1、MNIST数据集下载( 提取码: MN4S)

下载

2、加载数据

import os
import gzip
import numpy as np
 
#加载数据
def load_data(data_file):
    files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz']
    paths = []
    for fileName in files:
        paths.append(os.path.join(data_file, fileName))
        
    # 读取每个文件夹的数据    
    with gzip.open(paths[0], 'rb') as train_labels_path:
        train_labels = np.frombuffer(train_labels_path.read(), np.uint8, offset=8)
      
    with gzip.open(paths[1], 'rb') as train_images_path:
        train_images = np.frombuffer(train_images_path.read(), np.uint8, offset=16).reshape(len(train_labels), 784)
       
    with gzip.open(paths[2], 'rb') as test_labels_path:
        test_labels = np.frombuffer(test_labels_path.read(), np.uint8, offset=8)
        
    with gzip.open(paths[3], 'rb') as test_images_path:
        test_images = np.frombuffer(test_images_path.read(), np.uint8, offset=16).reshape(len(test_labels), 784)
        
    return train_labels,train_images,test_labels,test_images

train_labels,train_images,test_labels,test_images = load_data('MNIST/')

#打印形状
print(train_labels.shape)
print(train_images.shape)
print(test_labels.shape)
print(test_images.shape)

打印结果:

(60000,)
(60000, 784)
(10000,)
(10000, 784)

3、代码解释

  • with gzip.open(paths[0], 'rb') as train_labels_path::使用gzip.open()函数打开名为train_labels_path的压缩文件,以二进制模式('rb')读取。

  • train_labels = np.frombuffer(train_labels_path.read(), np.uint8, offset=8):从压缩文件中读取数据,将其转换为NumPy数组。np.frombuffer()函数用于从缓冲区中读取数据,并将其转换为指定类型的NumPy数组。这里的offset=8表示从第8个字节开始读取数据,因为通常情况下,训练标签文件是以8字节的整数对(一个标签值和一个标签类别)的形式存储的。

  • with gzip.open(paths[1], 'rb') as train_images_path::使用gzip.open()函数打开名为train_images_path的压缩文件,以二进制模式('rb')读取。

  • train_images = np.frombuffer(train_images_path.read(), np.uint8, offset=16).reshape(len(train_labels), 784):从压缩文件中读取数据,将其转换为NumPy数组。同样地,这里使用np.frombuffer()函数从缓冲区中读取数据,并将其转换为指定类型的NumPy数组。由于训练图像文件是以16字节的整数对(一个图像值和一个图像尺寸)的形式存储的,所以需要将读取到的数据重新排列为784x1x28x28的形状,其中784是每个图像的像素数量,28是每个图像的高度和宽度。这里的len(train_labels)表示训练标签的数量,因为每个训练图像都有一个对应的标签。

开发板推荐:天空星STM32F407VET6开发板

超高性价比 STM32主控 | 超高主频 | 一板兼容百芯 | 比赛神器 | 沉金彩色丝印

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

缘起性空、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值