1. Dataset & DataLoader 核心定位
表格
| 组件 | 本质定位 | 核心价值 |
|---|---|---|
Dataset | 数据 “生产者” | 标准化单样本的读取、预处理流程,让数据来源 / 格式对下游透明(无论图片 / 文本 / 表格,都返回统一格式) |
DataLoader | 数据 “调度器” | 解决批量加载、并行读取、数据打乱等工程问题,提升训练效率,屏蔽底层数据读取的性能瓶颈 |
2. 自定义 Dataset 的核心规范
- 必须继承
torch.utils.data.Dataset; - 必须实现
__len__():返回数据集总样本数; - 必须实现
__getitem__(idx):根据索引返回单个样本(通常是(数据, 标签)); - 预处理逻辑建议通过
transform参数传入(解耦数据读取和预处理)。
3. DataLoader 的核心参数与作用
表格
| 参数名 | 作用 | 实战建议 |
|---|---|---|
batch_size | 每个批次的样本数 | 显卡显存足够时设大(如 64/128),不足时减小 |
shuffle | 是否打乱数据顺序 | 训练集设True,验证 / 测试集设False |
num_workers | 多进程加载数据的进程数 | 设为 CPU 核心数(如 4/8),Windows 需注意路径问题 |
pin_memory | 是否将数据加载到锁页内存(加速 GPU 传输) | GPU 训练时设True,CPU 训练设False |
drop_last | 是否丢弃最后一个不完整批次 | 训练时设True(避免批次大小不一致),评估时设False |
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import pandas as pd
import os
import random
# ==================== 案例1:自定义图像Dataset(分类任务) ====================
print("=== 1. 自定义图像Dataset + DataLoader ===")
# 步骤1:模拟图像数据集(创建示例文件夹和图片)
def create_dummy_image_dataset(root="./dummy_image_data"):
"""创建模拟图像数据集(分类:cat/dog)"""
# 定义类别和样本数
classes = ["cat", "dog"]
num_samples_per_class = 20
# 创建目录
for cls in classes:
cls_dir = os.path.join(root, "train", cls)
os.makedirs(cls_dir, exist_ok=True)
# 生成随机图像并保存
for cls in classes:
cls_dir = os.path.join(root, "train", cls)
for i in range(num_samples_per_class):
# 生成随机3通道图像(224×224)
img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
img_pil = Image.fromarray(img)
img_pil.save(os.path.join(cls_dir, f"{cls}_{i}.jpg"))
print(f"模拟图像数据集创建完成,路径:{root}")
return root
# 创建模拟数据集
data_root = create_dummy_image_dataset()
# 步骤2:自定义图像Dataset类
class ImageClassificationDataset(Dataset):
def __init__(self, root, transform=None, train=True):
"""
自定义图像分类Dataset
:param root: 数据根目录
:param transform: 预处理变换(None/Compose)
:param train: 是否为训练集
"""
self.root = os.path.join(root, "train" if train else "val")
self.transform = transform
self.classes = os.listdir(self.root) # 获取类别列表(cat/dog)
self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)} # 类别→标签映射
self.samples = self._load_samples() # 加载所有样本路径和标签
def _load_samples(self):
"""加载所有样本的路径和标签"""
samples = []
for cls in self.classes:
cls_dir = os.path.join(self.root, cls)
for img_name in os.listdir(cls_dir):
img_path = os.path.join(cls_dir, img_name)
label = self.class_to_idx[cls]
samples.append((img_path, label))
return samples
def __len__(self):
"""返回总样本数(必须实现)"""
return len(self.samples)
def __getitem__(self, idx):
"""根据索引返回单个样本(必须实现)"""
# 1. 读取原始数据
img_path, label = self.samples[idx]
image = Image.open(img_path).convert("RGB") # 读取为RGB图像
# 2. 预处理(解耦:transform由外部传入)
if self.transform is not None:
image = self.transform(image)
# 3. 返回标准化格式(张量+标签)
return image, label
# 步骤3:定义图像预处理变换
train_transform = transforms.Compose([
transforms.Resize((224, 224)), # 缩放
transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转(数据增强)
transforms.ToTensor(), # 转为张量(0-1)
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化
])
val_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 步骤4:初始化Dataset
train_dataset = ImageClassificationDataset(
root=data_root,
transform=train_transform,
train=True
)
val_dataset = ImageClassificationDataset(
root=data_root,
transform=val_transform,
train=True # 演示用,实际应拆分验证集
)
# 步骤5:初始化DataLoader
train_loader = DataLoader(
train_dataset,
batch_size=8,
shuffle=True, # 训练集打乱
num_workers=0, # Windows建议设0,Linux/Mac可设为CPU核心数
pin_memory=True if torch.cuda.is_available() else False,
drop_last=True # 丢弃最后不完整批次
)
val_loader = DataLoader(
val_dataset,
batch_size=8,
shuffle=False, # 验证集不打乱
num_workers=0,
pin_memory=True if torch.cuda.is_available() else False,
drop_last=False
)
# 测试DataLoader加载
print(f"\n训练集样本总数:{len(train_dataset)}")
print(f"训练集批次总数:{len(train_loader)}")
# 迭代DataLoader(模拟训练过程)
for batch_idx, (images, labels) in enumerate(train_loader):
print(f"批次{batch_idx+1}:")
print(f" 图像形状:{images.shape} (batch_size, channel, H, W)")
print(f" 标签形状:{labels.shape} (batch_size)")
print(f" 标签值:{labels.numpy()}")
if batch_idx >= 1: # 仅演示前2批次
break
# ==================== 案例2:自定义表格Dataset(回归任务) ====================
print("\n=== 2. 自定义表格Dataset + DataLoader ===")
# 步骤1:模拟表格数据集(房价预测)
def create_dummy_tabular_dataset(path="./dummy_house_data.csv"):
"""创建模拟房价预测数据集"""
np.random.seed(42)
# 特征:面积(㎡)、卧室数、楼层、建造年份
data = {
"area": np.random.uniform(50, 200, 100),
"bedrooms": np.random.randint(1, 5, 100),
"floor": np.random.randint(1, 30, 100),
"build_year": np.random.randint(1990, 2020, 100),
"price": 100 + 2*np.random.uniform(50, 200, 100) + 5*np.random.randint(1,5,100) # 房价(万)
}
df = pd.DataFrame(data)
df.to_csv(path, index=False)
print(f"模拟表格数据集创建完成,路径:{path}")
return path
# 创建模拟数据集
tabular_path = create_dummy_tabular_dataset()
# 步骤2:自定义表格Dataset类
class TabularRegressionDataset(Dataset):
def __init__(self, csv_path, transform=None):
"""
自定义表格回归Dataset
:param csv_path: CSV文件路径
:param transform: 特征预处理变换(如归一化)
"""
self.df = pd.read_csv(csv_path)
self.features = self.df.drop("price", axis=1).values # 特征列
self.targets = self.df["price"].values # 目标列(房价)
self.transform = transform
def __len__(self):
return len(self.df)
def __getitem__(self, idx):
# 1. 读取单样本特征和标签
feature = self.features[idx].astype(np.float32)
target = self.targets[idx].astype(np.float32)
# 2. 特征预处理(如归一化)
if self.transform is not None:
feature = self.transform(feature)
# 3. 转为张量并返回
return torch.from_numpy(feature), torch.tensor(target)
# 步骤3:定义表格数据预处理(归一化)
class TabularNormalize:
def __init__(self, mean, std):
self.mean = mean
self.std = std
def __call__(self, x):
"""实现__call__,让类实例可作为transform使用"""
return (x - self.mean) / self.std
# 计算特征均值和标准差(训练集统计)
df = pd.read_csv(tabular_path)
features = df.drop("price", axis=1).values.astype(np.float32)
feature_mean = np.mean(features, axis=0)
feature_std = np.std(features, axis=0)
# 初始化表格Dataset
tabular_dataset = TabularRegressionDataset(
csv_path=tabular_path,
transform=TabularNormalize(feature_mean, feature_std)
)
# 初始化表格DataLoader
tabular_loader = DataLoader(
tabular_dataset,
batch_size=10,
shuffle=True,
num_workers=0
)
# 测试表格DataLoader
print(f"\n表格数据集样本总数:{len(tabular_dataset)}")
print(f"表格数据集批次总数:{len(tabular_loader)}")
for batch_idx, (features, targets) in enumerate(tabular_loader):
print(f"批次{batch_idx+1}:")
print(f" 特征形状:{features.shape} (batch_size, num_features)")
print(f" 房价标签:{targets.numpy()[:5]}") # 打印前5个房价
if batch_idx >= 1:
break
# ==================== 进阶:Dataset与DataLoader核心原理验证 ====================
print("\n=== 3. 核心原理验证 ===")
# 验证1:Dataset的单样本读取
sample_idx = 0
image_sample, label_sample = train_dataset[sample_idx]
print(f"\nDataset单样本验证:")
print(f" 图像形状:{image_sample.shape}")
print(f" 标签值:{label_sample}")
# 验证2:DataLoader的批量拼接逻辑
# DataLoader会自动将Dataset返回的单样本拼接为批次(维度扩展)
first_batch_images, first_batch_labels = next(iter(train_loader))
print(f"\nDataLoader批量验证:")
print(f" 单样本形状:{image_sample.shape} → 批次形状:{first_batch_images.shape}")
print(f" 单标签形状:() → 批次形状:{first_batch_labels.shape}")
# 验证3:shuffle参数的作用
# 对比shuffle=True/False的标签顺序
print(f"\nshuffle参数验证:")
# shuffle=False的Loader
val_loader_shuffle_false = DataLoader(val_dataset, batch_size=8, shuffle=False)
# 取第一个批次的标签
labels_shuffle_false = next(iter(val_loader_shuffle_false))[1].numpy()
# shuffle=True的Loader
labels_shuffle_true = next(iter(train_loader))[1].numpy()
print(f" shuffle=False 标签顺序:{labels_shuffle_false}")
print(f" shuffle=True 标签顺序:{labels_shuffle_true}")
# ==================== 实战技巧:高效数据加载优化 ====================
print("\n=== 4. 数据加载优化技巧 ===")
def optimize_dataloader():
"""DataLoader性能优化建议"""
tips = [
"1. num_workers设置:Linux/Mac设为CPU核心数(如os.cpu_count()),Windows建议设0(避免多进程路径问题)",
"2. pin_memory=True:GPU训练时开启,加速数据从CPU到GPU的传输",
"3. 预处理逻辑尽量放在Dataset的__getitem__中,利用num_workers并行处理",
"4. 大数据集建议使用DatasetFolder/ImageFolder(PyTorch内置),避免重复写__getitem__",
"5. 数据增强(如随机翻转)仅在训练集的transform中使用,验证/测试集禁用",
"6. drop_last=True:训练时开启,避免最后一个批次样本数不足导致的batch norm报错"
]
for tip in tips:
print(f" {tip}")
optimize_dataloader()
399

被折叠的 条评论
为什么被折叠?



