工业质检实战:从零构建基于ResUNet的钢板缺陷分割系统
在工业制造领域,质量检测一直是保障产品可靠性的核心环节。传统的质检方式依赖人工目视,不仅效率低下,而且容易因疲劳导致漏检误判。随着计算机视觉技术的成熟,基于深度学习的自动化缺陷检测方案正在彻底改变这一局面。Kaggle平台上著名的Severstal钢板缺陷检测竞赛,就是一个典型的工业视觉分割任务——它要求算法不仅要识别钢板表面是否存在缺陷,还要精确地定位出缺陷的轮廓与类别。
对于希望进入工业AI应用领域的开发者而言,这个项目堪称绝佳的实战沙盒。它不像一些学术数据集那样“干净”,而是充满了真实工业场景的挑战:数据极度不平衡、缺陷形态多变、标注格式特殊(RLE编码)。单纯调用一个现成的模型库很难取得好成绩,你必须深入理解数据特性、精心设计模型架构、并巧妙处理训练中的各种陷阱。
本文将带你从零开始,完整复现并深度解析一个基于TensorFlow 2.x的ResUNet解决方案。我们不会止步于代码的简单罗列,而是会深入每一个技术决策的背后逻辑:为什么选择ResUNet?如何处理棘手的RLE编码?面对样本不均衡该用什么损失函数?后处理有哪些技巧能提升最终分数?我将结合自己多次参与类似竞赛的经验,分享那些在官方文档里找不到的实战细节和调优心得。
1. 竞赛任务解析与数据准备
Severstal竞赛的目标是检测钢板表面的四类缺陷:1)纵向裂纹,2)横向裂纹,3)孔洞,4)表面污渍。提供的训练数据包含数千张高分辨率钢板图像(1600x256像素),以及对应的缺陷标注。标注信息以Run-Length Encoding(RLE)格式存储,这是一种高效的二值掩码压缩表示法。
1.1 理解RLE编码:工业数据存储的智慧
RLE编码对于处理大规模工业图像数据至关重要。想象一下,一张1600x256的图片,如果直接存储二值掩码(0表示背景,1表示缺陷),需要409,600个二进制值。而大多数图片中缺陷区域占比极小,这种存储方式极其浪费空间。RLE的原理是记录连续像素段的起始位置和长度。
竞赛提供的train.csv文件格式如下:
| ImageId_ClassId | EncodedPixels |
|---|---|
| 0002cc93b.jpg_1 | 29102 12 29346 24 29602 24 ... |
| 0002cc93b.jpg_2 | -1 |
| 0002cc93b.jpg_3 | 144512 5 144517 8 ... |
| 0002cc93b.jpg_4 | -1 |
EncodedPixels列即为RLE字符串。-1表示该图片在该类别下无缺陷。字符串中的数字成对出现:起始像素位置 长度。这里有一个关键细节:像素位置是按列优先(column-major)顺序展开的。这与我们常见的行优先(row-major)的numpy或OpenCV数组存储方式不同,解码时需要特别注意。
注意:RLE解码时,起始位置通常是从1开始计数的,而编程中的数组索引从0开始,因此解码时需要将起始位置减1。
下面是一个将RLE字符串解码为二维掩码矩阵的实用函数。我对其进行了优化,加入了错误处理和边界检查:
import numpy as np
def rle_to_mask(rle_string, height, width):
"""
将RLE编码字符串转换为二维掩码图像。
参数:
rle_string (str): RLE编码字符串,如 '29102 12 29346 24'
height (int): 掩码图像的高度
width (int): 掩码图像的宽度
返回:
numpy.ndarray: 形状为(height, width)的二值掩码,缺陷区域为1,背景为0
"""
# 处理无缺陷的情况
if rle_string == '-1' or pd.isna(rle_string):
return np.zeros((height, width), dtype=np.uint8)
# 将字符串分割并转换为整数列表
rle_numbers = list(map(int, rle_string.split()))
# 确保数字成对出现
if len(rle_numbers) % 2 != 0:
raise ValueError(f"RLE字符串 '{rle_string}' 包含奇数个数字,无法成对解析。")
# 创建一维数组(列优先展开)
flat_mask = np.zeros(height * width, dtype=np.uint8)
# 填充缺陷区域
for start, length in zip(rle_numbers[0::2], rle_numbers[1::2]):
# RLE起始位置从1开始,转换为从0开始的索引
start_idx = start - 1
# 确保索引不越界
if start_idx + length > len(flat_mask):
# 实际中可能会遇到,进行截断处理
length = len(flat_mask) - start_idx
flat_mask[start_idx:start_idx + length] = 1
# 重塑为二维数组(先按列重塑,再转置为按行)
mask = flat_mask.reshape(width, height).T # 注意:reshape(width, height)然后转置
return mask
对应的,将模型预测的掩码编码回RLE格式用于提交,也需要遵循相同的列优先规则:
def mask_to_rle(mask):
"""
将二维二值掩码转换为RLE编码字符串。
参数:
mask (numpy.ndarray): 二维二值数组,缺陷区域为1,背景为0
返回:
str: RLE编码字符串
"""
# 将掩码按列优先展平
pixels = mask.T.flatten()
# 在首尾添加0,便于检测边界变化
pixels = np.concatenate([[0], pixels, [0]])
# 找到值发生变化的位置(从0变1或从1变0)
runs = np.where(pixels[1:] != pixels[:-1])[0] + 1 # +1是因为前面添加了一个0
# runs中偶数索引位置是段落的开始位置,奇数索引位置是段落的结束位置
# 将结束位置转换为长度
runs[1::2] -= runs[::2]
# 转换为字符串,注意起始位置要加1(因为RLE从1开始计数)
rle_string = ' '.join(str(x) for x in runs)
return rle_string if rle_string.strip() else '-1'
1.2 数据探索与不平衡问题分析
加载数据后,第一件要做的事就是分析缺陷的分布情况。你会发现这是一个典型的极端类别不平衡问题:
import pandas as pd
import matplotlib.pyplot as plt
# 加载训练标签
train_df = pd.read_csv('/kaggle/input/severstal-steel-defect-detection/train.csv')
# 统计各类缺陷的数量
defect_stats = train_df.groupby('ClassId')['EncodedPixels'].apply(
lambda x: (x != '-1').sum()
).reset_index(name='count')
print("各类缺陷样本数量统计:")
print(defect_stats)
# 可视化
plt.figure(figsize=(10, 6))
bars = plt.bar(defect_stats['ClassId'].astype(str), defect_stats['count'])
plt.title('钢板缺陷类别分布', fontsize=14)
plt.xlabel('缺陷类别', fontsize=12)
plt.ylabel('样本数量', fontsize=12)
# 在柱状图上显示具体数值
for bar in bars:
height = bar.get_height()
plt.text(bar.get_x() + bar.get_width()/2., height + 5,
f'{int(height)}', ha='center', va='bottom')
plt.tight_layout()
plt.show()
在我的分析中,四类缺陷的分布大致如下(具体数值因数据集版本可能略有不同):
| 缺陷类别 | 描述 | 典型样本数量 | 占比 |
|---|---|---|---|
| 类别1 | 纵向裂纹 | ~700 | 约3% |
| 类别2 | 横向裂纹 | ~100 | 约0.5% |
| 类别3 | 孔洞 | ~1500 | 约7% |
| 类别4 | 表面污渍 | ~13000 | 约89.5% |
这种不平衡会带来严重问题:模型会倾向于预测占主导的类别4,而忽略罕见的类别1和2。在工业质检中,漏检一个危险的裂纹(类别1或2)可能比误检一个污渍(类别4)后果严重得多。因此,我们需要在数据加载和损失函数设计上采取特殊策略。
2. 构建高效的数据管道
在工业视觉任务中,数据预处理管道的效率直接影响模型迭代速度。我们需要处理高分辨率图像、RLE解码、数据增强,同时还要应对内存限制。
2.1 自定义数据生成器
使用tf.keras.utils.Sequence创建自定义数据生成器是处理大型图像数据集的最佳实践。它支持多进程数据加载,并且能在每个epoch结束后自动打乱数据。
import tensorflow as tf
import cv2
import os
from sklearn.model_selection import train_test_split
class SteelDefectGenerator(tf.keras.utils.Sequence):
"""
钢板缺陷检测数据生成器。
支持实时数据增强、RLE解码、样本权重调整。
"""
def __init__(self, image_ids, rle_dict, image_dir, batch_size=8,
img_height=256, img_width=512, shuffle=True,
augment=False, class_weights=None):
"""
初始化数据生成器。
参数:
image_ids (list): 图片ID列表
rle_dict (dict): 以'ImageId_ClassId'为键,RLE字符串为值的字典
image_dir (str): 图片文件目录
batch_size (int): 批次大小
img_height, img_width (int): 目标图像尺寸
shuffle (bool): 是否在每个epoch后打乱数据
augment (bool): 是否启用数据增强
class_weights (dict): 各类别的样本权重,用于处理不平衡
"""
self.image_ids = image_ids
self.rle_dict = rle_dict
self.image_dir = image_dir
self.batch_size = batch_size
self.img_h = img_height
self.img_w = img_width
self.shuffle = shuffle
self.augment = augment
self.class_weights = class_weights or {1: 1.0, 2: 1.0, 3: 1.0, 4: 1.0}
# 如果启用增强,创建增强管道
if self.augment:
self.augmentation = self._create_augmentation_pipeline()
self.on_epoch_end()
def __len__(self):
"""返回每个epoch的批次数"""
return int(np.ceil(len(self.image_ids) / self.batch_size))
def __getitem__(self, index):
"""生成一个批次的数据"""
batch_ids = self.image_ids[index * self.batch_size:
(index + 1) * self.bapshot_size]
# 初始化批次数组
batch_images = np.zeros((len(batch_ids), self.img_h, self.img_w, 1),
dtype=np.float32)
batch_masks = np.zeros((len(batch_ids), self.img_h, self.img_w, 4),
dtype=np.float32)
batch_weights = np.ones((len(batch_ids), 4), dtype=np.float32)
for i, img_id in enumerate(batch_ids):
# 加载并预处理图像
img_path = os.path.join(self.image_dir, img_id)
image = self._load_and_preprocess_image(img_path)
# 加载并预处理掩码
mask = self._load_mask_for_image(img_id)
# 数据增强
if self.augment and np.random.random() > 0.5:
image, mask = self._apply_augmentation(image, mask)
batch_images[i] = image
batch_masks[i] = mask
# 为每个样本的每个类别设置权重
for class_idx in range(4):
if np.any(mask[..., class_idx] > 0):
batch_weights[i, class_idx] = self.class_weights[class_idx + 1]
return batch_images, batch_masks, batch_weights
def _

5589

被折叠的 条评论
为什么被折叠?



