
图1:大脑中统一且可复用的结构以及多时间尺度更新是解锁人类持续学习能力的关键组成部分。嵌套学习(NL)允许大脑的每个组件进行多时间尺度更新,同时表明像Transformer这样的知名架构实际上是具有不同频率更新的线性层。
图1翻译
统一且可复用的结构
神经可塑性是指大脑通过形成新的突触、增强或减弱现有突触、通过替代通路重新布线信号等机制来重组自身的能力。这种能力需要在全脑范围内具有统一且可复用的结构。
在嵌套学习(Nested Learning, NL)中,架构被分解为一组神经元(即线性或局部深度多层感知机MLPs),每个神经元都有其自身的上下文流和目标。该设计为学习提供了统一且可复用的结构。
多时间尺度更新
脑震荡(或脑波)对大脑协调其活动至关重要。值得注意的是,大脑并不依赖单一的集中式时钟来同步每个神经元:早期层在高频率周期中快速更新其活动,而后期层则在更长、更慢的周期中整合信息。
在NL中,每个“层级”中的参数以各自特定的频率进行更新,而不依赖于单一的集中式时钟。HOPE的设计允许早期层在高频率周期中快速更新其活动,而后期层则在更长、更慢的周期中整合信息。
图示说明:
-
左侧大脑图像标注了不同脑波频率:
- δ波(Delta Waves):0.5–4 Hz
- θ波(Theta Waves):4–8 Hz
- α波(Alpha Waves):8–12 Hz
- β波(Beta Waves):12–30 Hz
- γ波(Gamma Waves):30–100 Hz
-
右侧结构图显示从低频到最高频神经元的层级关系:
- 低频神经元 → 中频神经元 → 高频神经元 → 最高频神经元
- 每个层级连接到三个线性模块(q, k, v),表示注意力机制中的查询(Query)、键(Key)和值(Value)部分。

图2:嵌套学习范式,将机器学习模型及其训练过程表示为一组嵌套的优化问题。(左)混合架构示例。深度学习视角下的嵌套学习是扁平化的图像,无法提供关于模块中计算深度的见解,而嵌套学习则清晰地展示了所有内部梯度流。(右)神经学习模块:一种学习如何压缩自身上下文流的计算模型。例如,第一级对应模型最外层的训练,通常称为“预训练”步骤。
图2翻译
上图展示了逐层混合模型(即RNN + 自注意力机制)的嵌套学习(Nested Learning, NL)表示方式。虽然深度学习的表示方法(如背景中NL的扁平化图像)隐藏了模型内部的梯度流动,并将训练过程与架构分离,但嵌套学习使所有内部过程变得透明且可解释(白盒化)。
左侧部分:
- 展示了一个典型的深度学习模型结构,包含交替的RNN和自注意力(Attention)模块。
- 模型在训练过程中存在梯度流动(Gradient Flow),从输出端反向传播到输入端。
- 图中用不同层级的块状结构表示模型的层次性,每个层级都具有独立的梯度流动路径。
- 整体结构被简化为一个扁平化的图像,代表传统深度学习中的“黑箱”特性。
右侧部分(神经学习模块):
- 采用嵌套学习框架,将模型分解为多个层级(Level 1、Level 2、Level 3)。
- 每个层级包含一个“梯度流”模块,表示该层级内部的参数更新过程。
- 每个层级对应一组参数矩阵(M),通过优化目标函数进行更新:
- Level 1:
arg min M L ( M t ; K ( 1 ) , V ( 1 ) ) \arg\min_{\mathcal{M}} \mathcal{L}(\mathcal{M}_t; \mathbf{K}^{(1)}, \mathbf{V}^{(1)}) argMminL(Mt;K(1),V(1))
更新规则:
M t + 1 = M t − ∇ L ( M t ; K t ( 1 ) , V t ( 1 ) ) \mathcal{M}_{t+1} = \mathcal{M}_t - \nabla \mathcal{L}(\mathcal{M}_t; \mathbf{K}_t^{(1)}, \mathbf{V}_t^{(1)}) Mt+1=Mt−∇L(Mt;Kt(1),Vt(1)) - Level 2:
arg min M L ( M t ; K ( 2 ) , V ( 2 ) ) \arg\min_{\mathcal{M}} \mathcal{L}(\mathcal{M}_t; \mathbf{K}^{(2)}, \mathbf{V}^{(2)}) argMminL(Mt;K(2),V(2))
更新规则:
M t + 1 = M t − ∇ L ( M t ; K t ( 2 ) , V t ( 2 ) ) \mathcal{M}_{t+1} = \mathcal{M}_t - \nabla \mathcal{L}(\mathcal{M}_t; \mathbf{K}_t^{(2)}, \mathbf{V}_t^{(2)}) Mt+1=Mt−∇L(Mt;Kt(2),Vt(2)) - Level 3:
arg min M L ( M t ; K ( 3 ) , V ( 3 ) ) \arg\min_{\mathcal{M}} \mathcal{L}(\mathcal{M}_t; \mathbf{K}^{(3)}, \mathbf{V}^{(3)}) argMminL(Mt;K(3),V(3))
更新规则:
M t + 1 = M t − ∇ L ( M t ; K t ( 3 ) , V t ( 3 ) ) \mathcal{M}_{t+1} = \mathcal{M}_t - \nabla \mathcal{L}(\mathcal{M}_t; \mathbf{K}_t^{(3)}, \mathbf{V}_t^{(3)}) Mt+1=Mt−∇L(Mt;Kt(3),Vt(3))
- Level 1:
- 每个层级的参数更新基于其自身的上下文(K 和 V),体现了局部优化和分层学习的特点。
- 通过这种结构,嵌套学习实现了对模型内部过程的清晰可视化和可控性,提升了模型的可解释性和灵活性。
433

被折叠的 条评论
为什么被折叠?



