1、MNIST数据集下载( 提取码: MN4S)
2、加载数据
import os
import gzip
import numpy as np
#加载数据
def load_data(data_file):
files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz']
paths = []
for fileName in files:
paths.append(os.path.join(data_file, fileName))
# 读取每个文件夹的数据
with gzip.open(paths[0], 'rb') as train_labels_path:
train_labels = np.frombuffer(train_labels_path.read(), np.uint8, offset=8)
with gzip.open(paths[1], 'rb') as train_images_path:
train_images = np.frombuffer(train_images_path.read(), np.uint8, offset=16).reshape(len(train_labels), 784)
with gzip.open(paths[2], 'rb') as test_labels_path:
test_labels = np.frombuffer(test_labels_path.read(), np.uint8, offset=8)
with gzip.open(paths[3], 'rb') as test_images_path:
test_images = np.frombuffer(test_images_path.read(), np.uint8, offset=16).reshape(len(test_labels), 784)
return train_labels,train_images,test_labels,test_images
train_labels,train_images,test_labels,test_images = load_data('MNIST/')
#打印形状
print(train_labels.shape)
print(train_images.shape)
print(test_labels.shape)
print(test_images.shape)
打印结果:
(60000,)
(60000, 784)
(10000,)
(10000, 784)
3、代码解释
-
with gzip.open(paths[0], 'rb') as train_labels_path::使用gzip.open()函数打开名为train_labels_path的压缩文件,以二进制模式('rb')读取。 -
train_labels = np.frombuffer(train_labels_path.read(), np.uint8, offset=8):从压缩文件中读取数据,将其转换为NumPy数组。np.frombuffer()函数用于从缓冲区中读取数据,并将其转换为指定类型的NumPy数组。这里的offset=8表示从第8个字节开始读取数据,因为通常情况下,训练标签文件是以8字节的整数对(一个标签值和一个标签类别)的形式存储的。 -
with gzip.open(paths[1], 'rb') as train_images_path::使用gzip.open()函数打开名为train_images_path的压缩文件,以二进制模式('rb')读取。 -
train_images = np.frombuffer(train_images_path.read(), np.uint8, offset=16).reshape(len(train_labels), 784):从压缩文件中读取数据,将其转换为NumPy数组。同样地,这里使用np.frombuffer()函数从缓冲区中读取数据,并将其转换为指定类型的NumPy数组。由于训练图像文件是以16字节的整数对(一个图像值和一个图像尺寸)的形式存储的,所以需要将读取到的数据重新排列为784x1x28x28的形状,其中784是每个图像的像素数量,28是每个图像的高度和宽度。这里的len(train_labels)表示训练标签的数量,因为每个训练图像都有一个对应的标签。
该文介绍了如何下载MNIST数据集,并使用Python的gzip和numpy库加载和解析数据。通过加载函数,分别读取训练和测试标签及图像文件,将二进制数据转换为NumPy数组,调整形状以便后续处理。最后,展示了读取数据后的数组形状。
4231

被折叠的 条评论
为什么被折叠?



