目录
论文介绍
题目:
MAGNet: Multi-scale Awareness and Global fusion Network for RGB-D salient object detection | KBS
论文地址:
链接: https://www.sciencedirect.com/science/article/abs/pii/S0950705124007603
创新点
- 设计了一种轻量化的多模态融合网络:
- 提出了一种名为MAGNet(Multi-scale Awareness and Global fusion Network)的网络,用于RGB-D显著性目标检测。
- 通过16.1M的参数和9.9G的FLOPs实现了与先进方法相当的检测性能,同时大幅度减少了模型的复杂度。
- 模块创新:
- 多尺度感知融合模块(MAFM):充分利用低层特征图中的纹理信息和边缘信息,减少计算复杂度。
- 全局融合模块(GFM):结合注意力机制与卷积神经网络,增强高层特征图的语义信息。
- 多级卷积模块(MCM):用于逐步解码融合特征图,生成精细的预测结果。
- 跨模态特征融合:
- 在低层特征中,使用MAFM实现RGB和深度图特征的跨模态融合,以减少复杂背景和低光条件下的干扰。
- 在高层特征中,通过GFM设计全局融合,实现对RGB与深度图语义信息的全面整合。
- 性能优化与验证:
- 在6个公共数据集上进行实验,结果表明MAGNet不仅在精确度上优于现有方法,而且在参数量和计算复杂度上显著减少。
- 提供了一个轻量化版本MAGNet-S,进一步验证了其适应低计算资源环境的能力。
方法
模型总体架构
双流编码器
- RGB图像由 SMT(Swapped Mix Transformer) 提取多级特征。
- 深度图像由 MobileNetV2 提取多级特征。
- 这种设计结合了Transformer的全局感知能力和轻量化网络的高效性。
特征融合模块
- 低层特征融合:通过**多尺度感知融合模块(MAFM)**实现,融合RGB和深度图像的低层特征,充分利用纹理和边缘信息。
- 高层特征融合:通过**全局融合模块(GFM)**实现,将RGB和深度图像的语义信息进行全局关联和融合。
解码器
- 使用多级卷积模块(MCM),逐步将融合后的特征图解码为显著性目标图。
- MCM通过层级特征整合生成精细的显著性目标预测结果。

核心模块描述
多尺度感知融合模块(MAFM)
- 低层特征融合,结合了深度可分离卷积(DW)、点卷积(PW)以及多头混合卷积(MHMC)。
- 目的:降低计算复杂度的同时增强特征图的空间相关性。

全局融合模块(GFM)
- 高层特征融合,采用注意力机制结合卷积运算。
- 特点:
- 融合RGB和深度特征的全局语义信息。
- 通过注意力机制有效捕获跨模态的全局关联。

多级卷积模块(MCM)
- 解码器部分,每级特征通过上采样、深度可分离卷积和逐点卷积逐步整合。
- 目标:从低层到高层逐步恢复图像细节,生成高质量的显著性目标预测图。

即插即用模块作用
MCM作为一个即插即用模块:
特征融合与逐步解码
- MCM通过逐级特征整合,结合高层语义信息和低层细节信息,从而逐步恢复特征图中的细节。
- 该模块可以有效减少特征丢失,同时保留丰富的细节。
降低计算复杂度
- MCM中使用了深度可分离卷积(Depth-wise Convolution)和逐点卷积(Point-wise Convolution),极大降低了计算量。
- 对于资源受限的场景,MCM的轻量化设计显得尤为重要。
提升多层次特征的表达能力
- 高层特征中包含的全局语义信息可以通过MCM逐级整合至低层特征,补充细节。
- 低层特征可以帮助更精确地定位边缘和局部区域。
消融实验结果

- 对比不同主干网络组合对模型性能的影响:采用轻量化的MobileNetV2可以减少参数和计算量,但在RGB细节提取上表现不足。采用SMT作为RGB图像主干网络显著提升了模型性能。
- 结论:SMT和MobileNetV2的结合在计算效率和性能之间达成了良好平衡。

- 对比是否使用MAFM以及其他替代模块(PI和CMFM):添加MAFM后模型性能显著提升,MAE指标在各数据集上降低。与其他模块相比,MAFM以较少的参数实现了更高的检测精度。
- 结论:MAFM在低层特征融合中有效整合了RGB和深度特征,尤其在复杂场景下提升了模型的鲁棒性。

- 对比不同分辨率下的模型性能和计算量:提高分辨率可以提升检测精度,但会显著增加计算量(FLOPs)和降低推理速度。最终选择384×384作为平衡点,兼顾性能和效率。
- 结论:分辨率对模型性能和效率有直接影响,应根据应用需求选择合适的输入尺寸。

- 对比是否使用MHMC以及替代方法(单层卷积):使用MHMC的模型在多个数据集上的性能均优于其他方法。
- 结论:MHMC在MAFM中通过捕获多尺度的相关性增强了RGB和深度特征的融合。

- 对比是否使用GFM以及替代方法(SCA和AF模块):添加GFM后,模型在多数据集上均有性能提升。GFM的性能略优于其他方法(如SCA和AF),尤其在复杂场景中表现更好。
- 结论:GFM能够更有效地融合RGB和深度特征的全局语义信息。
即插即用模块代码
import torch.nn as nn
import torch
import torch.nn.functional as F
# 论文:MAGNet: Multi-scale Awareness and Global fusion Network for RGB-D salient object detection | KBS
# 论文地址:https://www.sciencedirect.com/science/article/abs/pii/S0950705124007603
# github地址:https://github.com/mingyu6346/MAGNet
TRAIN_SIZE = 384
class MCM(nn.Module):
def __init__(self, inc, outc):
super().__init__()
self.upsample2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.rc = nn.Sequential(
nn.Conv2d(in_channels=inc, out_channels=inc, kernel_size=3, padding=1, stride=1, groups=inc),
nn.BatchNorm2d(inc),
nn.GELU(),
nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1, stride=1),
nn.BatchNorm2d(outc),
nn.GELU()
)
self.predtrans = nn.Sequential(
nn.Conv2d(in_channels=outc, out_channels=outc, kernel_size=3, padding=1, groups=outc),
nn.BatchNorm2d(outc),
nn.GELU(),
nn.Conv2d(in_channels=outc, out_channels=1, kernel_size=1)
)
self.rc2 = nn.Sequential(
nn.Conv2d(in_channels=outc * 2, out_channels=outc * 2, kernel_size=3, padding=1, groups=outc * 2),
nn.BatchNorm2d(outc * 2),
nn.GELU(),
nn.Conv2d(in_channels=outc * 2, out_channels=outc, kernel_size=1, stride=1),
nn.BatchNorm2d(outc),
nn.GELU()
)
def forward(self, x1, x2):
x2_upsample = self.upsample2(x2) # 上采样
x2_rc = self.rc(x2_upsample) # 减少通道数
shortcut = x2_rc
x_cat = torch.cat((x1, x2_rc), dim=1) # 拼接
x_forward = self.rc2(x_cat) # 减少通道数2
x_forward = x_forward + shortcut
pred = F.interpolate(self.predtrans(x_forward), TRAIN_SIZE, mode="bilinear", align_corners=True) # 预测图
return pred, x_forward
if __name__ == '__main__':
inc = 64 # 输入通道数
outc = 32 # 输出通道数
mcm = MCM(inc=inc, outc=outc)
x1 = torch.randn(1, outc, 96, 96) # Batch size=1, Channels=outc, Height=96, Width=96
x2 = torch.randn(1, inc, 48, 48) # Batch size=1, Channels=inc, Height=48, Width=48
pred, x_forward = mcm(x1, x2)
print(x1.size())
print(x2.size())
print(pred.size())
print(x_forward.size())
1878

被折叠的 条评论
为什么被折叠?



