【端到端智驾基础】3. 为什么要用Deformable Attention?

Deformable Attention在这里插入图片描述

1. 为什么需要Deformable Attention?

前面我们已经学习了标准Attention公式,标准的Attention机制可以用以下公式表示:
Attention(Q,K,V)=softmax(QKTdk)V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
其中:

  • QQQ:查询矩阵(Query)
  • KKK:键矩阵(Key)
  • VVV:值矩阵(Value)
  • dkd_kdk:键向量的维度
  • dk\sqrt{d_k}dk:缩放因子,防止点积过大导致softmax梯度消失

我们也举过例子,我们就拿Cross Attention 来说,相当于我们带着找几本自动驾驶方面书籍的Query,去图书馆里的Key/Value中去查找相关书籍。
但是,Attention是怎么做的,就是用我们的query去图书馆中的每一本书的key/value去比对,然后看哪几本符合我们要找的。
大家可以想象一下,这得多傻啊。我们直接去标着“自动驾驶类”的书架上去找不就完了吗?

2. Deformable Attention简单理解

Deformable Attention和标准的Attention不一样,不去进行全局密集的注意力,而是在一些参考点周围去进行Attention。就像上面我说的,我们去图书馆找自动驾驶书籍,直接去标着“智驾类”的书架上找就行了。
前面query- based的BEV特征Encoder中也讲过,他是拿BEV空间中每个栅格内的query去和特征图像上每个像素去Attention,可是这种全局密集attention有必要吗?比如左前视角看到一辆车,那左前部分的BEV query就去和左前视角的图像一部分区域去做Attention即可。

标准Attention是数据驱动,就是硬找。其实我们是有一些先验信息可以拿到的,比如相机的内外参。通过内外参我们就可以找到BEV某个栅格中query能对应上图像中的哪个区域。

Deformable Attention是标准Attention的改进版本,它允许注意力机制关注非规则位置,下面就是他的公式:
DeformableAttention(q,pq,x)=∑m=1MWm[∑k=1KAmqk⋅x(pq+Δpmqk)] \text{DeformableAttention}(q, p_q, x) = \sum_{m=1}^{M} W_m \left[ \sum_{k=1}^{K} A_{mqk} \cdot x(p_q + \Delta p_{mqk}) \right] DeformableAttention(q,pq,x)=m=1MWm[k=1KAmqkx(pq+Δpmqk)]

其中:

  • qqq:查询索引
  • pqp_qpq:参考点位置
  • xxx:输入特征图
  • MMM:注意力头的数量
  • KKK:采样的键值对数量
  • AmqkA_{mqk}Amqk:第mmm个头中第kkk个采样点的注意力权重
  • Δpmqk\Delta p_{mqk}Δpmqk:可学习的偏移量

3. Deformable Attention详解

Deformable Attention 新增两大独有模块:参考点 Reference Point、可学习偏移 Offset
整体流程 6 步:

  1. 为每个 Query 生成参考点p(几何先验:BEV 栅格通过相机外内参投影到图像,得到初始像素坐标)
  2. 网络预测KKK组可学习偏移(Δpqk)(\Delta p_{qk})(Δpqk),修正投影误差
  3. 计算KKK个采样坐标:(pqk=p+Δpqk)(p_{qk}=p+\Delta p_{qk})(pqk=p+Δpqk)(浮点坐标,不在整数像素上)
  4. 对浮点坐标用双线性插值从图像特征图取出K组V特征
  5. 网络预测KKK个采样点的注意力权重(Aqk)(A_{qk})(Aqk)(替代标准 QK 点积相似度)
  6. 对采样特征加权求和,输出 Query 最终特征
    在这里插入图片描述
    我们照着这张图以及下面的公式看上面的流程。
    DeformableAttention(q,pq,x)=∑m=1MWm[∑k=1KAmqk⋅x(pq+Δpmqk)] \text{DeformableAttention}(q, p_q, x) = \sum_{m=1}^{M} W_m \left[ \sum_{k=1}^{K} A_{mqk} \cdot x(p_q + \Delta p_{mqk}) \right] DeformableAttention(q,pq,x)=m=1MWm[k=1KAmqkx(pq+Δpmqk)]

其中:

  • qqq:查询索引
  • pqp_qpq:参考点位置
  • xxx:输入特征图
  • MMM:注意力头的数量
  • KKK:采样的键值对数量
  • AmqkA_{mqk}Amqk:第mmm个头中第kkk个采样点的注意力权重
  • Δpmqk\Delta p_{mqk}Δpmqk:可学习的偏移量
  1. 先来看输入:query/参考点/Value对应公式里的(q,pq,x)(q, p_q, x)(q,pq,x), 后面会理解为什么没有了Key。
    这里的参考点就是我们通过内外参把当前BEV栅格位置投影到图像上的位置。
  2. 再来看 x(pq+Δpmqk)x(p_q + \Delta p_{mqk})x(pq+Δpmqk)pqp_qpq是参考点,然后Δpmqk\Delta p_{mqk}Δpmqk是在参考点附近做的偏移。我们看这个偏移是怎么来的,他是query经过一个Linear线性网络预测出来的。这里用的多头Attention,预测了多个头的偏移。外面这个函数x(∗)x(*)x()是双线性插值得到Value。有几个头就会产生几组Value。
  3. 接着我们再来看这个AmqkA_{mqk}Amqk,这就是前面每个头的Value需要的权重。之前是Q∗KQ*KQK得到,现在是通过一个linear+Soft Max得到。同样是几个头得几组权重。然后再和前面的几组Value聚合,最后经过一个linear输出最终的结果。

4. 四、PyTorch 极简伪代码(对照上面流程,逐行注释)

import torch
import torch.nn as nn
import torch.nn.functional as F

class SingleHeadDeformAttn(nn.Module):
    def __init__(self, d_model, n_sample=8):
        super().__init__()
        self.d_model = d_model
        self.K = n_sample  # 每个Query采样8个点
        
        # 1. 预测偏移 + 注意力权重 MLP
        self.offset_mlp = nn.Linear(d_model, self.K * 2)  # K组(dx, dy)
        self.attn_mlp = nn.Linear(d_model, self.K)         # K个注意力权重
        # Value投影矩阵
        self.w_v = nn.Linear(d_model, d_model)

    def bilinear_sample(self, feat_map, sample_coords):
        """
        feat_map: [B, C, H, W] 图像特征图
        sample_coords: [B, N, K, 2] 归一化采样坐标(0~1)
        return: [B, N, K, C] 采样得到的特征
        """
        B, C, H, W = feat_map.shape
        N = sample_coords.shape[1]
        # grid_sample要求坐标范围[-1,1],转换归一化坐标
        grid = sample_coords * 2 - 1
        # F.grid_sample 批量采样
        sampled_feat = F.grid_sample(
            feat_map, grid, mode="bilinear", align_corners=False
        )  # [B, C, N, K]
        return sampled_feat.permute(0,2,3,1)  # [B,N,K,C]

    def forward(self, query, feat_map, ref_points_norm):
        """
        query: [B, N, d_model] BEV Query序列
        feat_map: [B, C, H, W] 图像K/V特征图
        ref_points_norm: [B, N, 2] 归一化后的参考点坐标(0~1)
        """
        B, N, _ = query.shape
        
        # 1. 预测K组偏移量 [B,N,K,2]
        offset = self.offset_mlp(query).view(B, N, self.K, 2)
        # 参考点广播 + 偏移得到最终采样坐标
        ref_expand = ref_points_norm.unsqueeze(2).repeat(1,1,self.K,1)
        sample_coords = ref_expand + offset

        # 2. 双线性插值采样图像特征 [B,N,K,d_model]
        v_feat = self.w_v(feat_map.permute(0,2,3,1)).permute(0,3,1,2)
        sampled_v = self.bilinear_sample(v_feat, sample_coords)

        # 3. 预测注意力权重并归一化 [B,N,K]
        attn_weight = self.attn_mlp(query).view(B, N, self.K)
        attn_weight = F.softmax(attn_weight, dim=-1)
        
        # 4. 加权求和得到输出 [B,N,d_model]
        output = torch.sum(attn_weight.unsqueeze(-1) * sampled_v, dim=2)
        return output

# 多头可变形注意力封装
class MultiHeadDeformAttn(nn.Module):
    def __init__(self, d_model, n_head=8, n_sample=8):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.d_k = d_model // n_head
        # 构建多个独立单头
        self.heads = nn.ModuleList([
            SingleHeadDeformAttn(self.d_k, n_sample) for _ in range(n_head)
        ])
        self.w_out = nn.Linear(d_model, d_model)

    def forward(self, query, feat_map, ref_points_norm):
        B, N, _ = query.shape
        # 分头
        query_split = torch.split(query, self.d_k, dim=-1)
        head_outputs = []
        for i, head in enumerate(self.heads):
            head_out = head(query_split[i], feat_map, ref_points_norm)
            head_outputs.append(head_out)
        # 拼接多头输出
        concat = torch.cat(head_outputs, dim=-1)
        return self.w_out(concat)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

宛如新生

转发即鼓励,打赏价更高!哈哈。

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值