Train Once, Reuse Everywhere: Generalizable Implicit In-Context Learning by Routing Attention
开源地址:https://github.com/Lijiaqian1/In-Context-Routing.git
arXiv:2509.22854v2 | ICML 2026
一、背景
LLM通过上下文学习ICL(在查询前插入少量标注示例作为上下文演示ICDs),使模型无需更新参数,即能够适应新任务,但传统ICL存在两大痛点:(1) 插入示例增加序列长度,导致推理成本剧增;(2) 性能脆弱,高度依赖示例的顺序和格式。
为解决这些问题,学术界提出了隐式ICL,即将示例转化为向量注入模型内部。然而,现有的隐式ICL方法大多是在注意力计算完成后,向残差流中注入偏移向量。这种“事后补救”的方式缺乏理论支撑,且泛化能力差,无法迁移到未见过的域外(OOD)任务。
基于此,论文提出疑问:“能否找到一种通用的潜在模式支持隐式ICL,从而超越现有范式,实现跨场景的无缝泛化?”
为此,作者进一步深入研究attention space以识别、利用ICL模式,形式化地分析ICL如何被分解和嵌入到Attention Logits中的,提出ICR,该方法提取跨任务ICL模式,利用路由器将其作为一个低秩权重组合进行合并,以任务自适应的方式引导注意力计算,实现“训练一次,到处复用”。
二、方法
(一)注意力路由 Attention Routing
1. 预备知识:基于向量的隐式ICL及其挑战
- 显式的ICL提示词将p=[D,xq]p=[D,x_q]p=[D,xq]输入给LLM时,p通常由标注好的示例ICDs:D={(xi,yi)}i=1nD=\{(x_i,y_i)\}^n_{i=1}D={(xi,yi)}i=1n和一个查询采样xqx_qxq构成。
- 基于向量的隐式ICL方法,认为ICDs可被视为对Transformer的多注意头MHA在零样本下的输出进行叠加修改,一种普遍的方法是将ICDs产生的激活差值作为偏移向量注入到零样本隐藏状态中。在标准Transformer中,第 lll 层的注意力输出 hlh^lhl 由多头注意力(MHA)计算得出,M是掩码:hl=Concath(softmax(Al,h)Vl,h)=Concath(softmax(Ql,hKl,h⊤dk+M)Vl,h)h^l=Concat_h(softmax(A^{l,h})V^{l,h})=Concat_h(softmax(\frac{Q^{l,h}K^{l,h \top}}{\sqrt{d_k}}+M)V^{l,h})hl=Concath(softmax(Al,h)Vl,h)=Concath(softmax(dkQl,hKl,h⊤+M)Vl,h)
现有的隐式ICL方法(如ICV, I2CL)通过在注意力输出后注入一个偏移向量 VshiftV_{shift}Vshift 来模拟ICL效果,βl\beta^lβl控制偏移程度:
h~tl=htl+βl⋅Vshiftl \tilde{h}_t^l = h_t^l + \beta_l \cdot V_{shift}^l h~tl=htl+βl⋅Vshiftl
挑战: 这种方法是在注意力聚合后进行事后干预,没有改变底层 QQQ 和 KKK 的交互几何结构,导致泛化能力受限。
2. 注意力路由的提出与形式化推导
- 作者认为,调整注意力空间内的匹配结构更能反映ICL的机制,Attention Logits由query-key的交互来驱动,将每个ICL提示的最后一个token视为整合上下文信息的整合点。通过分析其query和key的投影结果,能够捕捉到各类任务中因引入ICDs而产生的系统性变化。这些变化会形成一个低维子空间,用以表征具备泛化能力的ICL动态特征。
- 为获得此子空间:
- 获取PIDs(主ICL方向):通过将ICL提示词输入LLM,可获得一个跨领域的高维度注意力表征,并收集各领域最后一个token的Q和K,形成两个ICL基,对ICL的 QQQ 和 KKK 基矩阵采用PCA,获得每层的主要ICL方向(Principal ICL Directions PIDs)矩阵 Uql,Ukl∈Rd×rU_q^l, U_k^l \in \mathbb{R}^{d \times r}Uql,Ukl∈Rd×r。
- 构建低秩偏置:定义路由向量αl∈Rr\alpha^l \in \mathbb{R}^rαl∈Rr,该向量为第lll层的PIDs分配权重。αl\alpha_lαl控制每个PID调节注意力的作用强度。在零样本推理时,模型原本将MHA中每个头的Q、K拼接起来得到 Qzsl,Kzsl∈Rd×rQ_{zs}^l, K_{zs}^l \in \mathbb{R}^{d \times r}Qzsl,Kzsl∈Rd×r,而这里引入路由向量对Attention Logits进行低秩调制,进而将注意力动态偏向提取出的PIDs:
ΔAl=(QzslUql)diag(αl)(KzslUkl)⊤∈RT×T \Delta A^l = (Q_{zs}^l U_q^l) \text{diag}(\alpha^l) (K_{zs}^l U_k^l)^\top \in \mathbb{R}^{T \times T} ΔAl=(QzslUql)diag(αl)(KzslUkl)⊤∈RT×T
这里每层的偏置ΔAl\Delta A^lΔAl被第l层的所有头共享,所以每个注意力头变为:
A~l,h=Al,h+ΔAl \tilde{A}^{l,h} = A^{l,h} + \Delta A^l A~l,h=Al,h+ΔAl
核视角解释:这等价于将注意力机制的线性核 K0(q,k)=q⊤kK_0(q,k)=q^\top kK0(q,k)=q⊤k 重参数化为 Kα(q,k)=q⊤Ml(αl)kK_\alpha(q,k) = q^\top M^l(\alpha^l) kKα(q,k)=q⊤Ml(αl)k,其中 Ml=Id+Uqldiag(αl)Ukl⊤M^l = I_d + U_q^l \text{diag}(\alpha^l) {U_k^l}^\topMl=Id+Uqldiag(αl)Ukl⊤。这证明路由是对核空间的结构性低秩修改。
PCA:因为原本的Q、K向量基维度高,且包含大量与ICL无关的信息,所以使用PCA,找到一组正交的新基(主成分),让原始向量在这组新基上的投影方差最大,用最少的维度,保留原始数据最多的信息。从所有ICL样本的Q/K向量中,过滤噪声、提取出对 ICL 行为贡献最大的核心方向。
3. PIDs如何捕获通用ICL模式(理论证明)
为什么PID能提取出“通用”的ICL模式?作者利用混合尖峰协方差模型进行了数学证明:
- 将单个领域 ddd 下的Q矩阵协方差分解为(the Spiked Covariance Model形式):ΣQ(d)=SqΛqSq⊤⏟共享子空间+Bq,dΓq,dBq,d⊤⏟特定领域子空间+σ2I⏟噪声\Sigma_Q^{(d)} = \underbrace{S_q \Lambda_q S_q^\top}_{共享子空间} + \underbrace{B_{q,d} \Gamma_{q,d} B_{q,d}^\top}_{特定领域子空间} + \underbrace{\sigma^2 I}_{噪声}ΣQ(d)=共享子空间SqΛqSq⊤+特定领域子空间Bq,dΓq,dBq,d⊤+噪声σ2I
- 对D个领域的协方差进行合并:Σ^Q=1N∑d=1D∑i∈DdQiQi⊤,N=∑d=1D∣Dd∣\widehat{\Sigma}_Q = \frac{1}{N} \sum_{d=1}^D \sum_{i \in \mathcal{D}_d} Q_i Q_i^\top, \qquad N = \sum_{d=1}^D |\mathcal{D}_d|ΣQ=N1d=1∑Di∈Dd∑QiQi⊤,N=d=1∑D∣Dd∣其混合尖峰形式(the mixed spiked form)为:E[Σ^Q]=SqΛqSq⊤+σ2I+1N∑d=1D∣Dd∣Bq,dΓq,dBq,d⊤\mathbb{E}\left[\widehat{\Sigma}_Q\right] = S_q \Lambda_q S_q^\top + \sigma^2 I + \frac{1}{N} \sum_{d=1}^D |\mathcal{D}_d| B_{q,d} \Gamma_{q,d} B_{q,d}^\topE[ΣQ]=SqΛqSq⊤+σ2I+N1d=1∑D∣Dd∣Bq,dΓq,dBq,d⊤当跨多个领域合并协方差 Σ^Q\hat{\Sigma}_QΣ^Q 时,特异子空间因方向不一致互相抵消,趋于各向同性噪声;而共享子空间 SqΛqSq⊤S_q \Lambda_q S_q^\topSqΛqSq⊤ 不断累积。
- 在Transformer中,注意力分数的计算核心是点积:A=QK⊤A = Q K^\topA=QK⊤。这意味着序列中两个Token之间的注意力权重,取决于它们的Query向量和Key向量在隐空间中的几何对齐程度。当收集大量ICL示例输入模型,并提取最后一个Token的 QQQ 和 KKK 向量时,这些向量并不是杂乱无章的。它们在隐空间中形成了一定的分布。协方差矩阵(ΣQ,ΣK\Sigma_Q, \Sigma_KΣQ,ΣK)正是用来描述这种分布形状的。
总结来说:Q和K的协方差矩阵,刻画了模型在进行上下文学习(ICL)时,内部信息流动的“底层几何结构”。- 既然协方差代表了注意力的路由结构,那么不同任务的协方差肯定不一样,怎么保证提取出来的PIDs是跨任务“通用”的,而不是某个任务特有的?作者借用了统计学中的混合Spiked协方差模型,证明通用项 SqΛqSq⊤S_q \Lambda_q S_q^\topSqΛqSq⊤:因为 SqS_qSq 在每个领域都存在,所以在混合求和时,它的能量是**累加(放大)**的。而特异项 Bq,dB_{q,d}Bq,d:因为不同领域的特有结构方向是不一致的,当把它们加在一起时,它们不会在一个方向上叠加,反而会互相抵消,最终趋向于各向同性的背景噪声。
- 既然通用结构 SqS_qSq 的能量被放大了,而领域特异结构 Bq,dB_{q,d}Bq,d 沦为了背景噪声,那么此时对 Σ^Q\hat{\Sigma}_QΣ^Q 进行 PCA(主成分分析),提取出来的前 rrr 个最大特征值对应的特征向量(即PIDs),必然就是通用子空间 SqS_qSq 的近似。

(二)上下文路由 In-Context Routing
基于注意力路由的基础,文章提出一种新的隐式ICL方法:In-Context Routing,(ICR)框架的完整落地分为三个阶段:
1. 主要ICL方向提取 (PIDs Extraction)
- 按前面Attention routing的思路,提取出ICL的主要方向(PIDs):对D个领域,对每个领域d构建ICL提示词集合PdP_dPd,N=∑d=1D∣Dd∣N=\sum_{d=1}^D |\mathcal{D}_d|N=∑d=1D∣Dd∣是所有领域样本总数。将每一层的所有领域进行合并,注意力头进行拼接得到key和query的ICL基:Q~l=stackd=1Dstacki=1∣Pd∣qd,il∈RN×dK~l=stackd=1Dstacki=1∣Pd∣kd,il∈RN×d\widetilde{Q}^l = \operatorname{stack}_{d=1}^D \operatorname{stack}_{i=1}^{|\mathcal{P}_d|} q_{d,i}^l \in \mathbb{R}^{N \times d} \\ \widetilde{K}^l = \operatorname{stack}_{d=1}^D \operatorname{stack}_{i=1}^{|\mathcal{P}_d|} k_{d,i}^l \in \mathbb{R}^{N \times d}Ql=stackd=1Dstacki=1∣Pd∣qd,il∈RN×dKl=stackd=1Dstacki=1∣Pd∣kd,il∈RN×d
- PCA降维:对收集到的 QQQ 和 KKK 矩阵分别做PCA,取前 rrr 个主成分,得到每层的 PIDs:Uql,UklU_q^l, U_k^lUql,Ukl。这些PIDs作为可复用的路由方向被缓存。
2. Query-conditioned路由器
为了让模型根据不同的输入动态调整路由,作者设计了一个轻量级路由器(LLM主体保持冻结)。
- 输入编码:输入查询 xxx 同时过冻结的LLM和冻结的文本编码器(如MiniLM),得到语义表示 E(x)E(x)E(x)。
- 双分支MLP:E(x)E(x)E(x) 输入两个两层MLP,生成两个矩阵(L是层数):
- 路由矩阵 α(x)=tanh(gθα(E(x)))∈RL×r\alpha(x) = \tanh(g_{\theta_\alpha}(E(x))) \in \mathbb{R}^{L \times r}α(x)=tanh(gθα(E(x)))∈RL×r,控制PIDs的权重(会根据查询x的语义自适应地放大或衰减提取出的PIDs)。
- 门控矩阵 γ(x)=σ(gθγ(E(x)))∈RL×H\gamma(x) = \sigma(g_{\theta_\gamma}(E(x))) \in \mathbb{R}^{L \times H}γ(x)=σ(gθγ(E(x)))∈RL×H,控制不同注意力头的贡献(调节各个注意力头的权重)。
- 动态融合:最终第 lll 层第 hhh 个头的Attention Logits变为:
A~l,h(x)=Al,h(x)+γl,h(x)(QzslUql)diag(αl(x))(KzslUkl)⊤ \tilde{A}_{l,h}(x) = A_{l,h}(x) + \gamma_{l,h}(x) (Q_{zs}^l U_q^l) \text{diag}(\alpha^l(x)) (K_{zs}^l U_k^l)^\top A~l,h(x)=Al,h(x)+γl,h(x)(QzslUql)diag(αl(x))(KzslUkl)⊤
3. 训练对象
仅训练路由器参数 (θα,θγ)(\theta_\alpha, \theta_\gamma)(θα,θγ)。总损失函数由四部分组成:
- 监督交叉熵 LCE\mathcal{L}_{CE}LCE:提供基础的语义监督。LCE=−1B∑i=1BlogPICR(yi∣xi)\mathcal{L}_{\mathrm{CE}} = -\frac{1}{B} \sum_{i=1}^B \log P^{\mathrm{ICR}}(y_i \mid x_i)LCE=−B1i=1∑BlogPICR(yi∣xi)
- 置信度对齐 Lconf\mathcal{L}_{conf}Lconf:通过熵,让路由预测的置信度至少不低于零样本预测,避免路由模块走捷径生成不确定性过高的预测结果,同时保证路由推理不会降低预测置信度。Lconf=1B∑i=1BReLU(H(softmax(piICR))−H(softmax(pizs))),H(q)=−∑v∈Vqvlogqv\mathcal{L}_{\text{conf}} = \frac{1}{B} \sum_{i=1}^B \operatorname{ReLU}\left( H\left(\operatorname{softmax}(p_i^{\mathrm{ICR}})\right) - H\left(\operatorname{softmax}(p_i^{\mathrm{zs}})\right) \right), \quad H(q) = -\sum_{v\in\mathcal{V}} q_v \log q_vLconf=B1i=1∑BReLU(H(softmax(piICR))−H(softmax(pizs))),H(q)=−v∈V∑qvlogqv
- 稀疏路由 Lspar\mathcal{L}_{spar}Lspar & Lgate\mathcal{L}_{gate}Lgate:对 α\alphaα 和 γ\gammaγ 施加 L1L_1L1 正则化,且对越深的层施加越强的稀疏惩罚权重 wlw^lwl。这迫使深层网络依赖更少但更关键的路由方向。Lspar=Ex[1L∑l=1Lwl∥αl(x)∥1r]Lgate=Ex[1L∑l=1L∥γl(x)∥1H]\mathcal{L}_{\mathrm{spar}} = \mathbb{E}_x\left[ \frac{1}{L} \sum_{l=1}^L w^l \frac{\|\alpha^l(x)\|_1}{r} \right] \\ \mathcal{L}_{\mathrm{gate}} = \mathbb{E}_x\left[ \frac{1}{L} \sum_{l=1}^L \frac{\|\gamma^l(x)\|_1}{H} \right]Lspar=Ex[L1l=1∑Lwlr∥αl(x)∥1]Lgate=Ex[L1l=1∑LH∥γl(x)∥1]
总损失:L=LCE+λconfLconf+λsparLspar+λgateLgate\mathcal{L} = \mathcal{L}_{CE} + \lambda_{conf}\mathcal{L}_{conf} + \lambda_{spar}\mathcal{L}_{spar} + \lambda_{gate}\mathcal{L}_{gate}L=LCE+λconfLconf+λsparLspar+λgateLgate
4. 推理
推理时,当输入随机零样本提示词,路由器根据输入动态生成 A~l,h(x)\tilde{A}_{l,h}(x)A~l,h(x),并在之后用玉解码。通过这种方式,无论输入内容是否属于训练过程中见过的领域,ICR都会借助Query-conditioned组合方式,沿着共享结构方向从根本上引导注意力动态,从而让零样本推理隐性具备ICL的效果。

三、实验
(一)实验设置
- 模型:Llama2-7B, Qwen2.5-7B等(含直至70B的规模扩展实验)。
- 数据集:5个In-Domain (ID) 数据集(AGNews, SST-2, TREC, CSQA, PIQA);7个OOD数据集,分为 Near-OOD(SST-5, MR, MRPC,与ID任务类型接近)和 Far-OOD(CB, COPA, CREAK, AI2SciE,推理范式发生根本改变)。
- 基线方法:Zero-shot, Few-shot, 以及 Task Vector (TV), Function Vector (FV), ICV, I2CL, LIVE, M2IV 等隐式ICL方法。
(二)主要结果
- ID任务:ICR媲美甚至超越Few-shot,且显著优于所有隐式ICL基线。
- OOD任务:其他基线方法在OOD任务上经常发生性能崩溃(低于Zero-shot),而ICR达到 0次崩溃。在Qwen2.5-7B上,ICR比最强隐式基线提升 +6.5%,比Few-shot提升 +2.7%。
- 对比LoRA:ICR使用的可训练参数仅为LoRA的1/2到1/3,却取得了更好的整体性能,尤其在OOD泛化上优势明显。
(三)消融实验
- PCA秩 rrr:r=8r=8r=8 最优。r=4r=4r=4 瓶颈太紧,r=12r=12r=12 引入未充分训练的自由度。
- PIDs有效性:将PCA替换为随机正交基,ID性能保留但OOD性能崩溃,证明PCA提取的“方向”本身是泛化的关键。
- 损失函数:去除 Lgate\mathcal{L}_{gate}Lgate 或 Lspar\mathcal{L}_{spar}Lspar 会导致远域OOD性能下降,证明稀疏约束对避免过度干预和提升迁移性至关重要。
- 路由层位置:仅在模型后1/3层(Late层)施加路由效果最好。
256

被折叠的 条评论
为什么被折叠?



