Deformable Attention
1. 为什么需要Deformable Attention?
前面我们已经学习了标准Attention公式,标准的Attention机制可以用以下公式表示:
Attention(Q,K,V)=softmax(QKTdk)V
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
Attention(Q,K,V)=softmax(dkQKT)V
其中:
- QQQ:查询矩阵(Query)
- KKK:键矩阵(Key)
- VVV:值矩阵(Value)
- dkd_kdk:键向量的维度
- dk\sqrt{d_k}dk:缩放因子,防止点积过大导致softmax梯度消失
我们也举过例子,我们就拿Cross Attention 来说,相当于我们带着找几本自动驾驶方面书籍的Query,去图书馆里的Key/Value中去查找相关书籍。
但是,Attention是怎么做的,就是用我们的query去图书馆中的每一本书的key/value去比对,然后看哪几本符合我们要找的。
大家可以想象一下,这得多傻啊。我们直接去标着“自动驾驶类”的书架上去找不就完了吗?
2. Deformable Attention简单理解
Deformable Attention和标准的Attention不一样,不去进行全局密集的注意力,而是在一些参考点周围去进行Attention。就像上面我说的,我们去图书馆找自动驾驶书籍,直接去标着“智驾类”的书架上找就行了。
前面query- based的BEV特征Encoder中也讲过,他是拿BEV空间中每个栅格内的query去和特征图像上每个像素去Attention,可是这种全局密集attention有必要吗?比如左前视角看到一辆车,那左前部分的BEV query就去和左前视角的图像一部分区域去做Attention即可。
标准Attention是数据驱动,就是硬找。其实我们是有一些先验信息可以拿到的,比如相机的内外参。通过内外参我们就可以找到BEV某个栅格中query能对应上图像中的哪个区域。
Deformable Attention是标准Attention的改进版本,它允许注意力机制关注非规则位置,下面就是他的公式:
DeformableAttention(q,pq,x)=∑m=1MWm[∑k=1KAmqk⋅x(pq+Δpmqk)]
\text{DeformableAttention}(q, p_q, x) = \sum_{m=1}^{M} W_m \left[ \sum_{k=1}^{K} A_{mqk} \cdot x(p_q + \Delta p_{mqk}) \right]
DeformableAttention(q,pq,x)=m=1∑MWm[k=1∑KAmqk⋅x(pq+Δpmqk)]
其中:
- qqq:查询索引
- pqp_qpq:参考点位置
- xxx:输入特征图
- MMM:注意力头的数量
- KKK:采样的键值对数量
- AmqkA_{mqk}Amqk:第mmm个头中第kkk个采样点的注意力权重
- Δpmqk\Delta p_{mqk}Δpmqk:可学习的偏移量
3. Deformable Attention详解
Deformable Attention 新增两大独有模块:参考点 Reference Point、可学习偏移 Offset。
整体流程 6 步:
- 为每个 Query 生成参考点p(几何先验:BEV 栅格通过相机外内参投影到图像,得到初始像素坐标)
- 网络预测KKK组可学习偏移(Δpqk)(\Delta p_{qk})(Δpqk),修正投影误差
- 计算KKK个采样坐标:(pqk=p+Δpqk)(p_{qk}=p+\Delta p_{qk})(pqk=p+Δpqk)(浮点坐标,不在整数像素上)
- 对浮点坐标用双线性插值从图像特征图取出K组V特征
- 网络预测KKK个采样点的注意力权重(Aqk)(A_{qk})(Aqk)(替代标准 QK 点积相似度)
- 对采样特征加权求和,输出 Query 最终特征

我们照着这张图以及下面的公式看上面的流程。
DeformableAttention(q,pq,x)=∑m=1MWm[∑k=1KAmqk⋅x(pq+Δpmqk)] \text{DeformableAttention}(q, p_q, x) = \sum_{m=1}^{M} W_m \left[ \sum_{k=1}^{K} A_{mqk} \cdot x(p_q + \Delta p_{mqk}) \right] DeformableAttention(q,pq,x)=m=1∑MWm[k=1∑KAmqk⋅x(pq+Δpmqk)]
其中:
- qqq:查询索引
- pqp_qpq:参考点位置
- xxx:输入特征图
- MMM:注意力头的数量
- KKK:采样的键值对数量
- AmqkA_{mqk}Amqk:第mmm个头中第kkk个采样点的注意力权重
- Δpmqk\Delta p_{mqk}Δpmqk:可学习的偏移量
- 先来看输入:query/参考点/Value对应公式里的(q,pq,x)(q, p_q, x)(q,pq,x), 后面会理解为什么没有了Key。
这里的参考点就是我们通过内外参把当前BEV栅格位置投影到图像上的位置。 - 再来看 x(pq+Δpmqk)x(p_q + \Delta p_{mqk})x(pq+Δpmqk) ,pqp_qpq是参考点,然后Δpmqk\Delta p_{mqk}Δpmqk是在参考点附近做的偏移。我们看这个偏移是怎么来的,他是query经过一个Linear线性网络预测出来的。这里用的多头Attention,预测了多个头的偏移。外面这个函数x(∗)x(*)x(∗)是双线性插值得到Value。有几个头就会产生几组Value。
- 接着我们再来看这个AmqkA_{mqk}Amqk,这就是前面每个头的Value需要的权重。之前是Q∗KQ*KQ∗K得到,现在是通过一个linear+Soft Max得到。同样是几个头得几组权重。然后再和前面的几组Value聚合,最后经过一个linear输出最终的结果。
4. 四、PyTorch 极简伪代码(对照上面流程,逐行注释)
import torch
import torch.nn as nn
import torch.nn.functional as F
class SingleHeadDeformAttn(nn.Module):
def __init__(self, d_model, n_sample=8):
super().__init__()
self.d_model = d_model
self.K = n_sample # 每个Query采样8个点
# 1. 预测偏移 + 注意力权重 MLP
self.offset_mlp = nn.Linear(d_model, self.K * 2) # K组(dx, dy)
self.attn_mlp = nn.Linear(d_model, self.K) # K个注意力权重
# Value投影矩阵
self.w_v = nn.Linear(d_model, d_model)
def bilinear_sample(self, feat_map, sample_coords):
"""
feat_map: [B, C, H, W] 图像特征图
sample_coords: [B, N, K, 2] 归一化采样坐标(0~1)
return: [B, N, K, C] 采样得到的特征
"""
B, C, H, W = feat_map.shape
N = sample_coords.shape[1]
# grid_sample要求坐标范围[-1,1],转换归一化坐标
grid = sample_coords * 2 - 1
# F.grid_sample 批量采样
sampled_feat = F.grid_sample(
feat_map, grid, mode="bilinear", align_corners=False
) # [B, C, N, K]
return sampled_feat.permute(0,2,3,1) # [B,N,K,C]
def forward(self, query, feat_map, ref_points_norm):
"""
query: [B, N, d_model] BEV Query序列
feat_map: [B, C, H, W] 图像K/V特征图
ref_points_norm: [B, N, 2] 归一化后的参考点坐标(0~1)
"""
B, N, _ = query.shape
# 1. 预测K组偏移量 [B,N,K,2]
offset = self.offset_mlp(query).view(B, N, self.K, 2)
# 参考点广播 + 偏移得到最终采样坐标
ref_expand = ref_points_norm.unsqueeze(2).repeat(1,1,self.K,1)
sample_coords = ref_expand + offset
# 2. 双线性插值采样图像特征 [B,N,K,d_model]
v_feat = self.w_v(feat_map.permute(0,2,3,1)).permute(0,3,1,2)
sampled_v = self.bilinear_sample(v_feat, sample_coords)
# 3. 预测注意力权重并归一化 [B,N,K]
attn_weight = self.attn_mlp(query).view(B, N, self.K)
attn_weight = F.softmax(attn_weight, dim=-1)
# 4. 加权求和得到输出 [B,N,d_model]
output = torch.sum(attn_weight.unsqueeze(-1) * sampled_v, dim=2)
return output
# 多头可变形注意力封装
class MultiHeadDeformAttn(nn.Module):
def __init__(self, d_model, n_head=8, n_sample=8):
super().__init__()
self.d_model = d_model
self.n_head = n_head
self.d_k = d_model // n_head
# 构建多个独立单头
self.heads = nn.ModuleList([
SingleHeadDeformAttn(self.d_k, n_sample) for _ in range(n_head)
])
self.w_out = nn.Linear(d_model, d_model)
def forward(self, query, feat_map, ref_points_norm):
B, N, _ = query.shape
# 分头
query_split = torch.split(query, self.d_k, dim=-1)
head_outputs = []
for i, head in enumerate(self.heads):
head_out = head(query_split[i], feat_map, ref_points_norm)
head_outputs.append(head_out)
# 拼接多头输出
concat = torch.cat(head_outputs, dim=-1)
return self.w_out(concat)
2572

被折叠的 条评论
为什么被折叠?



