
Transformers as Support Vector Machines

Davoud Ataee Tarzanagh1⋆ Yingcong Li2⋆ Christos Thrampoulidis3 Samet Oymak4†

arXiv:2308.16898v3 [cs.LG] 22 Feb 2024

Abstract
Since its inception in “Attention Is All You Need”, the transformer architecture has led to revolutionary advancements in natural language processing. The attention layer within the transformer admits a sequence of input tokens
X and makes them interact through pairwise similarities computed as softmax(XQK⊤X⊤), where (K, Q) are the
trainable key-query parameters. In this work, we establish a formal equivalence between the optimization geometry of
self-attention and a hard-margin SVM problem that separates optimal input tokens from non-optimal tokens using
linear constraints on the outer-products of token pairs. This formalism allows us to characterize the implicit bias of
1-layer transformers optimized with gradient descent, as follows. (1) Optimizing the attention layer, parameterized
by (K, Q), with vanishing regularization, converges in direction to an SVM solution minimizing the nuclear norm
of the combined parameter W := KQ⊤ . Instead, directly parameterizing by W minimizes a Frobenius norm SVM
objective. We characterize this convergence, highlighting that it can occur in locally-optimal directions rather than
global ones. (2) Complementing this, for W-parameterization, we prove the local/global directional convergence of
gradient descent under suitable geometric conditions. Importantly, we show that over-parameterization catalyzes global
convergence by ensuring the feasibility of the SVM problem and by guaranteeing a benign optimization landscape
devoid of stationary points. (3) While our theory applies primarily to linear prediction heads, we propose a more
general SVM equivalence that predicts the implicit bias of 1-layer transformers with nonlinear heads/MLPs. Our
findings apply to general datasets, trivially extend to the cross-attention layer, and their practical validity is verified via
thorough numerical experiments. We also introduce open problems and future research directions. We believe these
findings inspire a new perspective, interpreting multilayer transformers as a hierarchy of SVMs that separates and
selects optimal tokens.

1 Introduction
Self-attention, the central component of the transformer architecture, has revolutionized natural language processing
(NLP) by empowering the model to identify complex dependencies within input sequences [VSP+ 17]. By assessing the
relevance of each token to every other token, self-attention assigns varying degrees of importance to different parts of the
input sequence. This mechanism has proven highly effective in capturing long-range dependencies, which is essential
for applications arising in NLP [KT19, BMR+ 20, RSR+ 20], computer vision [FXM+ 21, LLC+ 21, TCD+ 21, CSL+ 23],
and reinforcement learning [JLL21, CLR+ 21, WWX+ 22]. Remarkable success of the self-attention mechanism and
transformers has paved the way for the development of sophisticated language models such as GPT4 [Ope23], Bard
[Goo23], LLaMA [TLI+ 23], and ChatGPT [Ope22].
Q: Can we characterize the optimization landscape and implicit bias of transformers? How does the
attention layer select and compose tokens when trained with gradient descent?
We address these questions by rigorously connecting the optimization geometry of the attention layer and a hard
max-margin SVM problem, namely (Att-SVM), that separates and selects the optimal tokens from each input sequence.
This formalism, which builds on the recent work [TLZO23], is practically meaningful as demonstrated through
experiments, and sheds light on the intricacies of self-attention. Throughout, given input sequences X, Z ∈ RT ×d with
length T and embedding dimension d, we study the core cross-attention and self-attention models:
fcross(X, Z) := S(ZQK⊤X⊤)XV,   (1a)
fself(X) := S(XQK⊤X⊤)XV.   (1b)

¹University of Pennsylvania, [email protected]. ²University of California, Riverside, [email protected]. ³University of British Columbia, [email protected]. ⁴University of Michigan, [email protected]. ⋆Equal contribution. †Corresponding author.

Figure 1: GD convergence during training of cross-attention weight W or (K, Q) with data. Teal and yellow markers represent tokens from X1 and X2, while stars mark optimal tokens. Solid lines in panels (a) W-parameterization and (b) (K, Q)-parameterization depict Att-SVM and Att-SVM⋆ directions mapped to z1 (red) and z2 (blue), respectively. Arrows illustrate GD trajectories converging towards these SVM directions. Red and blue dotted lines represent the corresponding separating hyperplanes.

Figure 2: Percentage of different convergence types when training cross-attention weights (W) using GD and varying dimension (d). Red and blue bars represent the percentages of convergence to globally-optimal and locally-optimal (including global) SVM solutions, respectively. Teal bars are complements of the blue bars. Larger overparameterization (d) increases the likelihood of global convergence.

Here, K, Q ∈ Rd×m, V ∈ Rd×v are the trainable key, query, and value matrices, respectively; S(·) denotes the softmax nonlinearity, which is applied row-wise on XQK⊤X⊤. Note that self-attention (1b) is a special instance of cross-attention (1a) obtained by setting Z ← X. To expose our main results, suppose the first token of Z – denoted by z – is used for
prediction. Concretely, given a training dataset (Yi, Xi, zi)ni=1 with labels Yi ∈ {−1, 1} and inputs Xi ∈ RT×d, zi ∈ Rd, we consider the empirical risk minimization with a decreasing loss function ℓ(·) : R → R, represented as follows:

L(K, Q) = (1/n) ∑i∈[n] ℓ(Yi · f(Xi, zi)),   where   f(Xi, zi) = h(Xi⊤ S(Xi KQ⊤ zi)).   (2)

Here, h(·) : Rd → R is the prediction head that subsumes the value weights V. In this formulation, the model f(·) precisely represents a one-layer transformer where an MLP follows the attention layer. Note that we recover self-attention in (2) by setting zi ← xi1, where xi1 denotes the first token of the sequence Xi.¹ The softmax operation, due to its nonlinear nature, poses a significant challenge when optimizing (2). The problem is nonconvex and nonlinear even when the prediction head is fixed and linear. In this study, we focus on optimizing the attention weights (K, Q, or W) and overcome such challenges to establish a fundamental SVM equivalence.²
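To make the model in (2) concrete, here is a minimal NumPy sketch (our own illustration, not the authors' code) of a single-head cross-attention layer with a fixed linear head h(x) = v⊤x, i.e. f(X, z) = v⊤X⊤S(XKQ⊤z); all variable names below are assumptions for this example.

```python
import numpy as np

def softmax(a):
    # Numerically stable softmax over a 1-D array of attention logits.
    a = a - a.max()
    e = np.exp(a)
    return e / e.sum()

def attention_output(X, z, K, Q):
    # Cross-attention with query token z: logits X K Q^T z in R^T; the output
    # is a convex combination (softmax-weighted average) of the rows of X.
    s = softmax(X @ K @ Q.T @ z)   # attention probabilities, shape (T,)
    return X.T @ s                 # composed token in R^d

def f_linear_head(X, z, K, Q, v):
    # One-layer transformer of Eq. (2) with fixed linear head h(x) = v^T x.
    return v @ attention_output(X, z, K, Q)

# Toy usage: T tokens of dimension d, head dimension m.
T, d, m = 5, 4, 3
rng = np.random.default_rng(0)
X, z, v = rng.standard_normal((T, d)), rng.standard_normal(d), rng.standard_normal(d)
K, Q = rng.standard_normal((d, m)), rng.standard_normal((d, m))
print(f_linear_head(X, z, K, Q, v))
```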
The paper’s main contributions are as follows:

• Implicit bias of the attention layer (Secs. 2-3). Optimizing the attention parameters (K, Q) with vanishing
regularization converges in direction towards a max-margin solution of (Att-SVM⋆ ) with the nuclear norm objective of
the combined parameter W := KQ⊤ (Thm 2). In the case of directly parameterizing cross-attention by the combined
parameter W, the regularization path (RP) directionally converges to (Att-SVM) solution with the Frobenius norm
objective. To our knowledge, this is the first result to formally distinguish the optimization dynamics of W vs (K, Q)
parameterizations, revealing the low-rank bias of the latter. Our theory clearly characterizes the optimality of selected
tokens (Definition 1) and naturally extends to sequence-to-sequence or causal classification settings (see SAtt-SVM and
Theorem 9 in appendix).
• Convergence of gradient descent (Secs. 4-5). Gradient descent (GD) iterates for the combined key-query variable
W converge in direction to a locally-optimal solution of (Att-SVM) with appropriate initialization and a linear head h(·)
(Sec. 5). For local optimality, selected tokens must have higher scores than their neighboring tokens. Locally-optimal
directions are not necessarily unique and are characterized in terms of the problem geometry. As a key contribution, we
identify geometric conditions that guarantee convergence to the globally-optimal direction (Sec. 4). Besides these, we
¹ Note that for simplicity, we set zi = xi1, but it can be any other row of Xi.
² We fix h(·) and only optimize the attention weights. This is partly to avoid the degenerate case where h(·) can be used to achieve zero training loss (e.g. via standard arguments like NTK [JGH18]) without providing any meaningful insight into the functionality of the attention mechanism.

show that over-parameterization (i.e. dimension d being large, and equivalent conditions) catalyzes global convergence
by ensuring (1) feasibility of (Att-SVM), and, (2) benign optimization landscape, in the sense that there are no stationary
points and no spurious locally-optimal directions (see Sec. 5.2). These are illustrated in Figures 1 and 2.
• Generality of SVM equivalence (Sec. 6). When optimizing with linear h(·), the attention layer is inherently biased
towards selecting a single token from each sequence (a.k.a. hard attention). This is reflected in (Att-SVM) and arises
from output tokens being convex combinations of the input tokens. In contrast, we show that nonlinear heads necessitate
composing multiple tokens, highlighting their importance in the transformer’s dynamics (Sec. 6.1). Using insights
gathered from our theory, we propose a more general SVM equivalence. Remarkably, we demonstrate that our proposal
accurately predicts the implicit bias of attention trained by gradient descent under general scenarios not covered by
theory (e.g. h(·) being an MLP). Specifically, our general formulae decouple attention weights into two components: A
directional component governed by SVM which selects the tokens by applying a 0-1 mask, and a finite component
which dictates the precise composition of the selected tokens by adjusting the softmax probabilities.
An important feature of these findings is that they apply to arbitrary datasets (whenever SVM is feasible) and are
numerically verifiable. We extensively validate the max-margin equivalence and implicit bias of transformers through
enlightening experiments. We hold the view that these findings aid in understanding transformers as hierarchical
max-margin token-selection mechanisms, and we hope that our outcomes will serve as a foundation for upcoming
studies concerning their optimization and generalization dynamics.
Overview. The paper is structured as follows: Section 2 introduces preliminaries on self-attention and optimization.
Section 3 analyzes self-attention’s optimization geometry, showing the RP of attention parameters converges to a
max-margin solution. Sections 4 and 5 present global and local gradient descent analyses, respectively, demonstrating
convergence of W, the key-query variable, towards the solution of (Att-SVM). Section 6 provides our results on
nonlinear prediction heads and generalized SVM equivalence. Section 8 discusses relevant literature. Finally, Section 9
concludes the paper with open problems and future research directions inspired by our findings. All proofs are deferred
to the appendix.

2 Preliminaries
Figure 3: Implicit biases of the attention layer and logistic regression. Gradient descent on the logistic loss separates training examples (SVM), whereas gradient descent on the attention layer (softmax) separates tokens of input sequences (Att-SVM; this work).

Unveiling the relationship between attention and linear SVMs. For linear classification, it is well-established that GD iterations on logistic loss and separable datasets converge towards the hard-margin SVM solution, which effectively separates the two classes within the data [SHN+ 18, RZH03, ZY05]. The softmax nonlinearity employed by the attention layer exhibits an exponentially-tailed behavior similar to the logistic loss; thus, attention may also be biased towards margin-maximizing solutions. However, the attention layer operates on tokens within an input sequence, rather than performing classification directly. Therefore, its bias is towards an SVM, specifically (Att-SVM), which aims to separate the tokens of input sequences by selecting the relevant ones and suppressing the rest. Nonetheless, formalizing this intuition is considerably more challenging: The presence of the highly nonlinear and nonconvex softmax operation renders the analysis of standard GD algorithms intricate. Additionally, the 1-layer transformer in (2) does not
inherently exhibit a singular bias towards a single (Att-SVM) problem, even when using a linear head h. Instead, it can
result in multiple locally optimal directions induced by their associated SVMs. We emphasize that [TLZO23] is the
first work to make this attention↔SVM connection. Here, we augment their framework to transformers by developing
the first guarantees for self/cross-attention layer, nonlinear prediction heads, and global convergence.
Notation. For any integer N ≥ 1, let [N] := {1, . . . , N}. We use lowercase and uppercase bold letters (e.g., a and A) to represent vectors and matrices, respectively. The entries of a vector a are denoted as ai. For a matrix A, ∥A∥ denotes the spectral norm, i.e. the maximum singular value, ∥A∥⋆ denotes the nuclear norm, i.e. the sum of all singular values, and ∥A∥F := √trace(A⊤A) denotes the Frobenius norm. dist(·, ·) denotes the Euclidean distance between two sets. The minimum / maximum of two numbers a, b is denoted as a ∧ b / a ∨ b. The big-O notation O(·) hides universal constants.
Optimization problem definition. We use a linear head h(x) = v⊤x for most of our theoretical exposition. Given the dataset (Yi, Xi, zi)ni=1, we minimize the empirical risk of a 1-layer transformer using either the combined weights W ∈ Rd×d or the individual weights K, Q ∈ Rd×m, for a fixed head and decreasing loss function:

L(W) = (1/n) ∑i∈[n] ℓ(Yi · v⊤ Xi⊤ S(Xi W zi)),   (W-ERM)

L(K, Q) = (1/n) ∑i∈[n] ℓ(Yi · v⊤ Xi⊤ S(Xi KQ⊤ zi)).   (KQ-ERM)

We can recover the self-attention model by setting zi to be the first token of Xi , i.e., zi ← xi1 . While the above
formulation regresses a single label Yi per (Xi , zi ), in Sections 6 and E, we show that our findings gracefully extend to
the sequence-to-sequence classification setting where we output and classify T tokens per inputs Xi , Z i ∈ RT ×d . See
Sections 6 and E for results on nonlinear prediction heads.
Optimization algorithms. Given a parameter R > 0, we consider an ℓ2 -norm bound R, and define the regularized path
solution associated with Objectives (W-ERM) and (KQ-ERM), respectively as (W-RP) and (KQ-RP). These update
rules allow us to find the solution within a constrained region defined by the norm bound. The RP illustrates the
evolution of W̄R as R increases, capturing the essence of GD where the ridge constraint serves as an approximation
for the number of iterations. Previous studies, including [RZH03, SPR18, JDST20, TLZO23], have examined the
implicit bias of logistic regression and established a connection between the directional convergence of the RP (i.e.,
limR→∞ W̄R /R) and GD. In [TLZO23], the concept of a local RP was also employed to investigate implicit bias along
local directions. For GD, with appropriate initialization and step size η > 0, we describe the optimization process
associated with (W-ERM) and (KQ-ERM) as (W-GD) and (KQ-GD), respectively.

Given W(0) ∈ Rd×d and η > 0, for k ≥ 0 do:
    W(k + 1) = W(k) − η∇L(W(k)).   (W-GD)

Given K(0), Q(0) ∈ Rd×m and η > 0, for k ≥ 0 do:
    K(k + 1) = K(k) − η∇K L(K(k), Q(k)),
    Q(k + 1) = Q(k) − η∇Q L(K(k), Q(k)).   (KQ-GD)

Given R > 0, find the d × d matrix:
    W̄R = arg min_{∥W∥F ≤ R} L(W).   (W-RP)

Given R > 0, find the d × m matrices:
    (K̄R, Q̄R) = arg min_{∥K∥²F + ∥Q∥²F ≤ 2R} L(K, Q).   (KQ-RP)
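As a concrete illustration of (W-GD), the following sketch (our own, using the logistic loss ℓ(u) = log(1 + e⁻ᵘ) and hand-derived gradients; not the paper's code) runs plain gradient descent on (W-ERM). In line with the theory, ∥W(k)∥F keeps growing while the direction W(k)/∥W(k)∥F stabilizes.

```python
import numpy as np

def softmax(a):
    a = a - a.max()
    e = np.exp(a)
    return e / e.sum()

def loss_and_grad(W, data, v):
    # (W-ERM) with logistic loss and its gradient in W. Uses the softmax
    # Jacobian: d/da [v^T X^T S(a)] = (diag(s) - s s^T) X v with a = X W z.
    n = len(data)
    loss, grad = 0.0, np.zeros_like(W)
    for Y, X, z in data:
        s = softmax(X @ W @ z)               # attention probabilities
        out = Y * (v @ (X.T @ s))            # margin Y * f(X, z)
        loss += np.log1p(np.exp(-out)) / n
        dldout = -1.0 / (1.0 + np.exp(out))  # derivative of log(1 + e^{-u})
        dout_da = Y * ((np.diag(s) - np.outer(s, s)) @ (X @ v))
        grad += dldout * np.outer(X.T @ dout_da, z) / n  # chain rule via a = X W z
    return loss, grad

def run_w_gd(data, v, d, eta=0.1, iters=1000):
    # Plain GD as in (W-GD), starting from W(0) = 0.
    W = np.zeros((d, d))
    for _ in range(iters):
        _, grad = loss_and_grad(W, data, v)
        W -= eta * grad
    return W
```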

2.1 Optimal tokens and hard-margin SVM problem for cross attention
Given Xi ∈ RT ×d , zi ∈ Rd , we present a convex hard-margin SVM problem, denoted as (Att-SVM), that aims to separate
a specific token from the remaining tokens in the input sequence Xi . This problem is jointly solved for all inputs,
allowing us to examine the optimization properties of cross-attention. To delve deeper, we introduce the concept of
optimal tokens, which are tokens that minimize the training objective under the decreasing loss function ℓ(·) as shown
in Lemma 2. This exploration will introduce the notions of token scores and optimality, providing insights into the
underlying principles of self-attention mechanisms [TLZO23].

Definition 1 (Token Score and Optimality) Given a prediction head v ∈ Rd , the score of a token xit of input Xi is
defined as γit = Yi · v⊤ xit . The optimal token for each input Xi is given by the index opti ∈ arg maxt∈[T ] γit for all i ∈ [n].

By introducing token scores and identifying optimal tokens, we can better understand the importance of individual
tokens and their impact on the overall objective. The token score quantifies the contribution of a token to the prediction
or classification task, while the optimal token represents the token that exhibits the highest relevance within the
corresponding input sequence.
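As a small illustration (with hypothetical data and our own helper name), the scores and optimal indices of Definition 1 can be computed as follows.

```python
import numpy as np

def optimal_tokens(Y, X_list, v):
    # Token scores gamma_{it} = Y_i * v^T x_{it} (Definition 1); the optimal
    # token of sequence i is the index with the largest score.
    return [int(np.argmax(Yi * (Xi @ v))) for Yi, Xi in zip(Y, X_list)]
```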
• Hard-margin SVM for W-parameterization. Equipped with the set of optimal indices opt := (opti )ni=1 as per
Definition 1, we introduce the following SVM formulation associated to (W-ERM):

W mm = arg min_W ∥W∥F   subj. to   (xiopti − xit)⊤ W zi ≥ 1 for all t ≠ opti, i ∈ [n].   (Att-SVM)

The existence of matrix W mm implies the separability of tokens (opti)ni=1 from the others. The terms ait := xit⊤ W zi represent the dot product between the key-query features before applying the softmax nonlinearity. This dot product is a crucial characteristic of self-attention and we can express the SVM constraints in (Att-SVM) as aiopti ≥ ait + 1. Thus,
(Att-SVM) finds the most efficient direction that ensures the optimal token xiopti achieves the highest similarity with
query zi among all key embeddings. Our first result shows that (Att-SVM) is feasible under mild over-parameterization.

Theorem 1 Suppose d ≥ max(T − 1, n). Then, almost all datasets (Yi , Xi , zi )ni=1 – including the self-attention setting
with zi ← xi1 – obey the following: (Att-SVM) is feasible i.e., W mm separates the desired tokens opt = (opti )ni=1 .

We note that the convex formulation (Att-SVM) does not fully capture the GD geometry on (W-ERM). In a more
general sense, GD can provably converge to an SVM solution over locally-optimal tokens, as detailed in Section 5.1.
For deeper insight into (Att-SVM), consider that the attention layer’s output is a convex mixture of input tokens. Thus,
if minimizing the training loss involves choosing the optimal token xiopti , the softmax similarities should eventually
converge to a one-hot vector, precisely including xiopti (assigned 1), while ignoring all other tokens (assigned 0s). This
convergence requires the attention weights W to diverge in norm to saturate the softmax probabilities. Due to the
exponential-tailed nature of the softmax function, the weights converge directionally to the max-margin solution. This
phenomenon resembles the implicit bias of logistic regression on separable data [SHN+ 18, JT18]. Lemma 2 formalizes
this intuition and rigorously motivates optimal tokens.
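For small instances, (Att-SVM) is an ordinary convex program and can be solved with an off-the-shelf solver; the CVXPY sketch below (our own illustration, with the indices opt supplied, e.g. by the helper above) returns W mm, or None when the problem is infeasible.

```python
import cvxpy as cp
import numpy as np

def solve_att_svm(X_list, z_list, opt, d):
    # (Att-SVM): minimize ||W||_F subject to the margin constraints
    # <(x_{i,opt_i} - x_{i,t}) z_i^T, W> >= 1 for all t != opt_i and all i.
    W = cp.Variable((d, d))
    constraints = []
    for Xi, zi, oi in zip(X_list, z_list, opt):
        for t in range(Xi.shape[0]):
            if t != oi:
                A = np.outer(Xi[oi] - Xi[t], zi)
                constraints.append(cp.sum(cp.multiply(A, W)) >= 1)
    problem = cp.Problem(cp.Minimize(cp.norm(W, "fro")), constraints)
    problem.solve()
    return W.value  # None if (Att-SVM) is infeasible
```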
• Non-convex SVM for (K, Q)-parameterization. The objective function (KQ-ERM) has an extra layer of nonconvexity compared to (W-ERM), as (K, Q) corresponds to a matrix factorization of W. To study this, we introduce the following nonconvex SVM problem over (K, Q), akin to (Att-SVM):

min_{K,Q} (1/2)(∥K∥²F + ∥Q∥²F)   subj. to   (xiopti − xit)⊤ KQ⊤ zi ≥ 1 for all t ≠ opti, i ∈ [n].   (KQ-SVM)

Even if the direction of GD is biased towards the SVM solution, it does not have to converge to the global minima
of (KQ-SVM). Instead, it can converge towards a Karush–Kuhn–Tucker (KKT) point of the max-margin SVM. Such
KKT convergence has been studied by [LL19, JT20] in the context of other nonconvex margin maximization problems.
Fortunately, (KQ-ERM) may not be as daunting as it initially seems: Our experiments in Figures 1 and 7 reveal that GD is indeed biased towards the global minima of (KQ-SVM). These global minima are achieved by setting W := KQ⊤ and finding the factorization of W that minimizes the quadratic objective, yielding the following W-parameterized SVM with nuclear norm objective:

W⋆mm ∈ arg min_{rank(W)≤m} ∥W∥⋆   subj. to   (xiopti − xit)⊤ W zi ≥ 1 for all t ≠ opti, i ∈ [n].   (Att-SVM⋆)

Above, the nonconvex rank constraint arises from the fact that the rank of W = KQ⊤ is at most m. However,
under the condition of the full parameterization where m ≥ d, the rank constraint disappears, leading to a convex
nuclear norm minimization problem. Besides, the nuclear norm objective inherently encourages a low-rank solution
[RFP10, Faz02, SRJ04]. Lemma 1, presented below, demonstrates that this guarantee holds whenever n ≤ m. This
observation is further supported by our experiments (see Fig. 4). Thus, it offers a straightforward rationale for why
setting m < d is a reasonable practice, particularly in scenarios involving limited data.

Lemma 1 Any optimal solution of (Att-SVM) or (Att-SVM⋆ ) is at most rank n. More precisely, the row space of W mm
or W⋆mm lies within span({zi }ni=1 ).
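With m = d the rank constraint disappears, so (Att-SVM⋆) is likewise a convex program; a CVXPY sketch (ours) mirroring the one above only swaps the objective for the nuclear norm, and np.linalg.matrix_rank of the result can be used to check the rank bound of Lemma 1 numerically.

```python
import cvxpy as cp
import numpy as np

def solve_att_svm_star(X_list, z_list, opt, d):
    # (Att-SVM*) with m = d: minimize the nuclear norm ||W||_* under the same
    # token-separation constraints as (Att-SVM).
    W = cp.Variable((d, d))
    constraints = [
        cp.sum(cp.multiply(np.outer(Xi[oi] - Xi[t], zi), W)) >= 1
        for Xi, zi, oi in zip(X_list, z_list, opt)
        for t in range(Xi.shape[0]) if t != oi
    ]
    problem = cp.Problem(cp.Minimize(cp.normNuc(W)), constraints)
    problem.solve()
    return W.value
```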

Figure 4 illustrates the ranks of solutions to (Att-SVM) and (Att-SVM⋆), denoted as W mm and W⋆mm, solved using the optimal tokens (opti)ni=1 and setting m = d (so the rank constraint is eliminated). Each result is averaged over 100 trials, and for each trial, xit, zi, and the linear head v are randomly sampled from the unit sphere. In Fig. 4(a), we fix T = 5 and vary n across {5, 10, 15}. Conversely, in Fig. 4(b), we keep n = 5 constant and alter T across {5, 10, 15}. Both figures confirm that the ranks of W mm and W⋆mm are bounded by min(n, d), validating Lemma 1.

Figure 4: Rank of the solutions to (Att-SVM) and (Att-SVM⋆), denoted as W mm and W⋆mm, solved using optimal tokens (opti)ni=1 and setting m = d (the rank constraint is eliminated). (a) Rank of attention SVM solutions with fixed T = 5; (b) rank of attention SVM solutions with fixed n = 5; both as a function of d. Both figures confirm that the ranks of W mm and W⋆mm are bounded by min(n, d), validating Lemma 1.

3 Understanding the Implicit Bias of Self-Attention


We start by motivating the optimal token definition and establishing the global convergence of RPs which shed light on
the implicit bias of attention parameterizations. Throughout, we maintain the following assumption regarding the loss
function.
Assumption A Over any bounded interval [a, b]: (i) ℓ : R → R is strictly decreasing; (ii) The derivative ℓ′ is bounded
as |ℓ′ (u)| ≤ M1 ; (iii) ℓ′ is M0 -Lipschitz continuous.
Assumption A encompasses many common loss functions, e.g., logistic ℓ (u) = log (1 + e−u ), exponential ℓ (u) = e−u ,
and correlation ℓ(u) = −u losses.
Lemma 2 (Optimal Tokens Minimize Training Loss) Suppose Assumption A (i)-(ii) holds, and not all tokens are optimal per Definition 1. Then, the training risk obeys L(W) > L⋆ := (1/n) ∑i∈[n] ℓ(γiopti). Additionally, suppose there are optimal indices (opti)ni=1 for which (Att-SVM) is feasible, i.e. there exists a W separating the optimal tokens. This W choice obeys limR→∞ L(R · W) = L⋆.
The result presented in Lemma 2 originates from the observation that the output tokens of the attention layer constitute
a convex combination of the input tokens. Consequently, when subjected to a strictly decreasing loss function, attention
optimization inherently leans towards the selection of a singular token, specifically, the optimal token (opti )ni=1 . Our
following theorem unveils the implicit bias ingrained within both attention parameterizations through RP analysis.

Theorem 2 Suppose Assumption A holds, the optimal indices (opti)ni=1 are unique, and (Att-SVM) is feasible. Let W mm be the unique solution of (Att-SVM), and let W⋆mm be the solution set of (Att-SVM⋆) with nuclear norm achieving objective C⋆. Then, Algorithms W-RP and KQ-RP, respectively, satisfy:

• W-parameterization has Frobenius norm bias: limR→∞ W̄R/R = W mm/∥W mm∥F.

• (K, Q)-parameterization has nuclear norm bias: limR→∞ dist(K̄R Q̄R⊤/R, W⋆mm/C⋆) = 0.

  – Setting m = d: (Att-SVM⋆) is a convex problem without rank constraints.

Theorem 2 demonstrates that the RP of the (K, Q)-parameterization converges to a max-margin solution of (Att-SVM⋆ )
with nuclear norm objective on W = KQ⊤ . When self-attention is directly parameterized by W, the RP converges
to the solution of (Att-SVM) with a Frobenius norm objective. This result is the first to distinguish the optimization
dynamics of W and (K, Q) parameterizations, revealing the low-rank bias of the latter. These findings also provide a
clear characterization of token optimality (Definition 1) and extend naturally to the setting with multiple optimal tokens
per sequence (Theorem 9 in appendix). By definition, the RP captures the global geometry and cannot be used for the

implicit bias of GD towards locally-optimal directions. Sections 4 and 5 accomplish this goal through gradient-descent
and localized RP analysis to obtain locally-applicable SVM equivalences. Note that this theorem requires each input sequence to have a unique optimal token per Definition 1. Fortunately, this is a very mild condition, as it holds for almost all datasets, namely, as soon as the input features are slightly perturbed.
Theorem 2 establishes the implicit bias of attention from the perspective of RP analysis. This leads to the question:
To what extent is this RP theory predictive of the implicit bias exhibited by GD? To delve into this, we examine the
gradient paths of W(k) or (K(k), Q(k)) and present the findings in Figure 1. We consider a scenario where n = d = m = 2
and T = 5, and employ cross-attention, where tokens (z1 , z2 ) are generated independently of the inputs (X1 , X2 ). The
teal and yellow markers correspond to tokens from X1 and X2 , respectively. The stars indicate the optimal token for
each input. To provide a clearer view of the gradient convergence path, we illustrate the outcomes of training the
attention weight W or (K, Q) in the form of W zi or KQ⊤ zi , where i = 1, 2. With reference to Equations (Att-SVM) and
(Att-SVM⋆ ), the red and blue solid lines in Fig. 1(a) delineate the directions of W mm z1 and W mm z2 , correspondingly.
Conversely, the red and blue solid lines in Fig. 1(b) show the directions of W⋆mm z1 and W⋆mm z2 . The red/blue arrows
denote the corresponding directions of gradient evolution with the dotted lines representing the corresponding separating
hyperplanes. Figure 1 provides a clear depiction of the incremental alignment of W(k) and K(k)Q(k)⊤ with their
respective attention SVM solutions as k increases. This strongly supports the assertions of Theorem 2.
It is worth noting that (Att-SVM⋆ ) imposes a nonconvex rank constraint, i.e., rank(W) ≤ m. Nevertheless, this
constraint becomes inconsequential if the unconstrained problem, with m set to be greater than or equal to d, admits a
low-rank solution, as demonstrated in Lemma 1. Consequently, in our experimental endeavors, we have the flexibility
to employ the unconstrained attention SVM for predicting the implicit bias. This concept is succinctly summarized by
the following lemma.

Lemma 3 Let W⋆mm be the solution set of (Att-SVM⋆) with nuclear norm achieving objective C⋆. Further, let Wcvxmm be the solution set of (Att-SVM⋆) with m = d achieving objective Ccvx. If W⋆mm ∩ Wcvxmm ≠ ∅, then C⋆ = Ccvx and W⋆mm ⊆ Wcvxmm. Also, if the elements of Wcvxmm have rank at most m, then W⋆mm = Wcvxmm.

4 Global Convergence of Gradient Descent


In this section, we will establish conditions that guarantee the global convergence of GD. Concretely, we will investigate when the GD solution selects the optimal token within each input sequence through the softmax nonlinearity and coincides with the solution of the RP. Section 5 will complement this by showing that self-attention can more generally converge to locally-optimal max-margin directions. We identify the following conditions as provable catalysts for global convergence: (i) Optimal tokens have relatively large scores; (ii) The initial gradient direction is favorable; (iii) Overparameterization, i.e. d is appropriately large.

4.1 Properties of optimization landscape


We start by establishing some fundamental properties of Objectives (W-ERM) and (KQ-ERM).
Lemma 4 Under Assumption A, ∇L(W), ∇K L(K, Q), and ∇Q L(K, Q) are LW-, LK-, and LQ-Lipschitz continuous, respectively, where ai = ∥v∥ ∥zi∥² ∥Xi∥³ and bi = M0 ∥v∥ ∥Xi∥ + 3M1 for all i ∈ [n], and

LW := (1/n) ∑i∈[n] ai bi,   LK := ∥Q∥ LW,   and   LQ := ∥K∥ LW.   (3)
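For example, a safe GD step size η ≤ 1/LW can be computed directly from the data via (3); the sketch below (ours) uses the logistic-loss constants M0 = 1/4 and M1 = 1.

```python
import numpy as np

def safe_step_size(X_list, z_list, v, M0=0.25, M1=1.0):
    # L_W from Eq. (3): a_i = ||v|| ||z_i||^2 ||X_i||^3 (spectral norm of X_i),
    # b_i = M0 ||v|| ||X_i|| + 3 M1, and L_W = (1/n) sum_i a_i b_i.
    LW = 0.0
    for Xi, zi in zip(X_list, z_list):
        op = np.linalg.norm(Xi, 2)      # spectral norm ||X_i||
        a = np.linalg.norm(v) * np.linalg.norm(zi) ** 2 * op ** 3
        b = M0 * np.linalg.norm(v) * op + 3.0 * M1
        LW += a * b
    LW /= len(X_list)
    return 1.0 / LW                     # eta <= 1/L_W, as required by the theory
```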

The next assumption will play an important role ensuring the attention layer has a benign optimization landscape.
Assumption B Optimal tokens’ indices (opti)ni=1 are unique and one of the following conditions on the tokens holds:
B.1 All tokens are support vectors, i.e., (xiopti − xit)⊤ W mm zi = 1 for all t ≠ opti and i ∈ [n].
B.2 The tokens’ scores, as defined in Definition 1, satisfy γit = γiτ < γiopti for all t, τ ≠ opti and i ∈ [n].
Assumption B.1 is directly linked to overparameterization and holds practical significance. In scenarios such as
classical SVM classification, where the goal is to separate labels, overparameterization leads to the situation where all
training points become support vectors. Consequently, the SVM solution aligns with the least-squares minimum-norm

Figure 5: Percentage of different convergence types of GD when training cross-attention weights (a) W or (b) (K, Q) with varying d. In both figures, red, blue, and teal bars represent the percentages of Global, Local (including Global), and Not Local convergence, respectively. The green bar corresponds to Assumption B.1, where all tokens act as support vectors. Larger overparameterization (d) relates to a higher percentage of globally-optimal SVM convergence.

interpolation, a concept established in [MNS+ 21, HMX21] under broad statistical contexts. Assumption B.1 represents
an analogous manifestation of this condition. Therefore, in cases involving realistic data distributions with sufficiently
large d, we expect the same phenomena to persist, causing the SVM solution W mm to coincide with (Att-SVM).
Drawing on insights from [HMX21, Theorem 1] and our Theorem 1, we expect that the necessary degree of
overparameterization remains moderate. Specifically, in instances where input sequences follow an independent and
identically distributed (IID) pattern and tokens exhibit IID isotropic distributions, we posit that d ≳ (T + n) log(T + n)
will suffice. More generally, the extent of required overparameterization will be contingent on the covariance of tokens
[BLLT20, MNS+ 21] and the distribution characteristics of input sequences [WT22].
Assumption B.2 stipulates that non-optimal tokens possess identical scores which constitutes a relatively stringent
assumption that we will subsequently relax. Under Assumption B, we establish that when optimization problem
(W-ERM) is trained using GD, the norm of parameters will diverge.

Theorem 3 Suppose Assumption A on the loss function ℓ and Assumption B on the tokens hold.

• There is no W ∈ Rd×d satisfying ∇L(W) = 0.


• Algorithm W-GD with the step size η ≤ 1/LW and any starting point W(0) satisfies limk→∞ ∥W(k)∥F = ∞.

The feasibility of SVM (per Theorem 1) is a necessary condition for the convergence of GD to the W mm direction.
However, it does not inform us about the optimization landscape. Two additional criteria are essential for convergence:
the absence of stationary points ∇L(W) = 0 and divergence of parameter norm to infinity. Theorem 3 above precisely
guarantees both of these criteria under Assumption B.

4.2 Provable global convergence of 1-layer transformer


In the quest for understanding the global convergence of a 1-layer transformer, [TLZO23, Theorem 2] provided the first
global convergence analysis of the attention in a restrictive scenario where n = 1 and under the assumption B.2. Here,
we present two new conditions for achieving global convergence towards the max-margin direction W mm based on: (I)
the initial gradient direction, and (II) over-parameterization. For the first case, we provide precise theoretical guarantees.
For the second, we offer strong empirical evidence, supported by Theorem 3, and a formal conjecture described in
Section 5.2. We remind the reader that we optimize the attention weights W while fixing the linear prediction head
h(x) = v⊤ x. This approach avoids trivial convergence guarantees where an over-parameterized h(·)—whether it is a
linear model or an MLP—can be used to achieve zero training loss without providing any meaningful insights into the
functionality of the attention mechanism.

Figure 6: Global convergence behavior of GD when training cross-attention weights W (solid) or (K, Q) (dashed) with random data. The blue, green, and red curves represent the probabilities of global convergence for (a) fixing T = 5 and varying n ∈ {5, 10, 20} and (b) fixing n = 5 and varying T ∈ {5, 10, 20}. Results demonstrate that for both attention models, as d increases (due to over-parameterization), attention weights tend to select the optimal tokens (opti)ni=1.

(I) Global convergence under good initial gradient. To ensure global convergence, we identify an assumption that
prevents GD from getting trapped at suboptimal tokens that offer no scoring advantage compared to other choices.
To establish a foundation for proving the convergence of GD to the globally optimal solution W mm, we need the following definitions. For parameters µ > 0 and R > 0, define

C̄µ,R := { W : ∥W∥F ≥ R,  ⟨(xiopti − xit)zi⊤, W/∥W∥F⟩ ≥ µ  for all t ≠ opti, i ∈ [n] }.   (4)

This is the set of all Ws that separate the optimal tokens from the rest with margin µ. We will show that, for any
µ > 0, the optimization landscape of this set is favorable and, if the updates remain in the set, the gradient descent will
maximize the margin and find W mm .

Assumption C (First GD Step is Separating) For some ι > 0 and all t ≠ opti, i ∈ [n]: (xit − xiopti)⊤ ∇L(0) zi ≥ ι.

Theorem 4 Suppose Assumption A on the loss function ℓ and Assumption C on the initial gradient hold.
• For any µ > 0, there exists R > 0 such that C̄µ,R does not contain any stationary points.
• Fix any µ ∈ (0, ι/∥∇L(0)∥F). Consider GD iterations with W(0) = 0, W(1) = −R∇L(0)/∥∇L(0)∥F, and W(k + 1) = W(k) − η∇L(W(k)) for k ≥ 1, where η ≤ 1/LW and R is sufficiently large. If all iterates remain within C̄µ,R, then limk→∞ ∥W(k)∥F = ∞ and limk→∞ W(k)/∥W(k)∥F = W mm/∥W mm∥F.

Note that the second result of Theorem 4, i.e., the divergence of the parameter norm to infinity and the directional convergence, requires that all GD iterations remain within C̄µ,R defined in (4). In Appendix C.2, we show that if, for all W ∈ C̄µ,R, the quantity min_{i∈[n]} ⟨(xiopti − xit)zi⊤, W − η∇L(W)⟩ − min_{i∈[n]} ⟨(xiopti − xit)zi⊤, W⟩ is lower bounded by (2ηµ/∥W mm∥F) ⟨−∇L(W), W mm⟩, then all GD iterations W(k) remain within C̄µ,R. While this condition may appear complicated, it is essentially a tight requirement for updates to remain within C̄µ,R. Finally, it is worth mentioning that, if a stronger correlation condition between the initial gradient ∇L(0) and W mm holds, one can also prove that updates remain within a tighter cone around W mm through the ideas developed in Theorem 5, by landing W(1) around the W mm direction. However, we opt to state here the result for the milder condition C̄µ,R.

(II) Global convergence via overparameterization. In the context of standard neural networks, overparameterization
has been recognized as a pivotal factor for the global convergence of GD [DZPS18, AZLS19, LL18, OS19]. However,
conventional approaches like the neural tangent kernel [JGH18] do not seamlessly apply to our scenario, given our
assumption on fixed h and the avoidance of achieving zero loss by trivially fitting h. Furthermore, even when we train h
to achieve zero loss, it doesn’t provide substantial insights into the implicit bias of the attention weights W. Conversely,

Figure 7: Local convergence behaviour of GD when training cross-attention weights W (blue) or (K, Q) (red) with random data: (a) displays the largest entry of the softmax outputs averaged over the dataset; (b) and (c) display the Pearson correlation coefficients of GD trajectories and the SVM solutions, (b) with the Frobenius norm objective Wαmm (solution of (Att-SVM)) and (c) with the nuclear norm objective W⋆,αmm (solution of (Att-SVM⋆)). These demonstrate the Frobenius norm bias of W(k) and the nuclear norm bias of K(k)Q(k)⊤.

Theorem 3 illustrates the benefits of over-parameterization in terms of convergence. Considering that Assumption B.1
is anticipated to hold as the dimension d increases, the norm of the GD solution is bound to diverge to infinity. This
satisfies a prerequisite for converging towards the globally-optimal SVM direction W mm .
The trend depicted in Figure 5, where the percentage of global convergence (red bars) approaches 100% and
Assumption B.1 holds with higher probability (green bars) as d grows, reinforces this insight. Specifically, Fig. 5(a)
is similar to Figure 2 but with additional green bars representing the percentage of the scenarios where almost all
tokens act as support vectors (Assumption B.1), and Fig. 5(b) displays the same evaluation over (K, Q)-parameterization
setting. In both experiments, and for each chosen d value, a total of 500 random instances are conducted under the
conditions of n = T = 5. The outcomes are reported in terms of the percentages of Not Local, Local, and Global
convergence, represented by the teal, blue, and red bars, respectively. We validate Assumption B.1 as follows: Given a problem instance, we compute the average margin over all non-optimal tokens of all inputs and declare that the problem satisfies Assumption B.1 if the average margin is below 1.1 (where 1 is the minimum).
Furthermore, the observations in Figure 6, where the percentage of global convergence reaches 100% for larger d, reaffirm that overparameterization leads the attention weights to converge directionally towards the optimal max-margin direction outlined by (Att-SVM) and (Att-SVM⋆).
In the upcoming section, we will introduce locally-optimal directions, to which GD can be proven to converge when
appropriately initialized. We will then establish a condition that ensures the globally-optimal direction is the sole viable
locally-optimal direction. This culmination will result in a formal conjecture detailed in Section 5.2.

5 Understanding Local Convergence of 1-Layer Transformer


So far, we have primarily focused on the convergence to the global direction dictated by (Att-SVM). In this section, we
investigate and establish the local directional convergence of GD as well as RP.

5.1 Local convergence of gradient descent


To proceed, we introduce locally-optimal directions by adapting Definition 2 of [TLZO23].

Definition 2 (Support Indices and Locally-Optimal Direction) Fix token indices α = (αi )ni=1 . Solve (Att-SVM) with
(opti )ni=1 replaced with α = (αi )ni=1 to obtain Wαmm . Consider the set Ti ⊂ [T ] such that (xiαi − xit )⊤Wαmm zi = 1 for all
t ∈ Ti . We refer to (Ti )ni=1 as the support indices of α. Additionally, if for all i ∈ [n] and t ∈ Ti scores per Definition 1
obey γiαi > γit , indices α = (αi )ni=1 are called locally-optimal and Wαmm is called a locally-optimal direction.

In words, local optimality requires that the selected tokens α have higher scores than their neighboring tokens, referred to as support indices. It is important to observe that the tokens opt = (opti)ni=1, which we term the optimal tokens, inherently satisfy the condition of local optimality. Moving forward, Theorem 5 establishes that when GD is initiated along a locally-optimal direction, it gradually converges in that particular direction, eventually aligning itself with Wαmm. This theorem immediately underscores the fact that if there exists a locally-optimal direction apart from the globally optimal direction W mm, then GD initialized at an arbitrary starting point is not guaranteed to converge globally towards W mm.
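Local optimality per Definition 2 can be checked numerically: solve (Att-SVM) with opt replaced by the candidate α (e.g. with the CVXPY sketch of Section 2.1), collect the support indices, and compare scores. The helper below is our own sketch; the tolerance is an assumption.

```python
import numpy as np

def is_locally_optimal(X_list, z_list, Y, v, alpha, W_alpha, tol=1e-6):
    # Definition 2: T_i = support indices (constraints active at margin 1 under
    # W_alpha); alpha is locally optimal if alpha_i out-scores every t in T_i.
    for Xi, zi, Yi, ai in zip(X_list, z_list, Y, alpha):
        scores = Yi * (Xi @ v)                    # token scores gamma_{it}
        margins = (Xi[ai] - Xi) @ W_alpha @ zi    # (x_{i,alpha_i} - x_{it})^T W z_i
        support = [t for t in range(Xi.shape[0])
                   if t != ai and abs(margins[t] - 1.0) <= tol]
        if any(scores[ai] <= scores[t] for t in support):
            return False
    return True
```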
To provide a basis for discussing local convergence of GD, we establish a cone centered around Wαmm using the
following construction. For parameters µ ∈ (0, 1) and R > 0, we define Cµ,R (Wαmm ) as the set of matrices W ∈ Rd×d such
that ∥W∥F ≥ R and the correlation coefficient between W and Wαmm is at least 1 − µ:

Cµ,R(Wαmm) = { W : ∥W∥F ≥ R,  ⟨W/∥W∥F, Wαmm/∥Wαmm∥F⟩ ≥ 1 − µ }.   (5)
Theorem 5 Suppose Assumption A on the loss ℓ holds, and let α = (αi )ni=1 be locally optimal tokens according to
Definition 2. Let Wαmm denote the SVM solution obtained via (Att-SVM) by replacing (opti )ni=1 with α = (αi )ni=1 .
• There exist parameters µ = µ(α) ∈ (0, 1) and R > 0 such that Cµ,R (Wαmm ) does not contain any stationary points.

• Algorithm W-GD with η ≤ 1/LW and any W(0) ∈ Cµ,R(Wαmm) satisfies limk→∞ ∥W(k)∥F = ∞ and limk→∞ W(k)/∥W(k)∥F = Wαmm/∥Wαmm∥F.

This theorem establishes the existence of positive parameters µ = µ(α) > 0 and R > 0 such that there are no stationary points within Cµ,R(Wαmm). Furthermore, if GD is initiated within Cµ,R(Wαmm), it will converge in the direction of Wαmm/∥Wαmm∥F. It is worth mentioning that the stronger Theorem 3 (e.g. global absence of stationary points) is applicable whenever all tokens are support vectors, i.e. T̄i = ∅ for all i ∈ [n].
In Figure 7, we consider a setting where n = 6, T = 8, and d = 10. The displayed results are averaged over 100 random trials. We train cross-attention models with xit, zi, v ∈ Rd randomly sampled from the unit sphere, and apply the normalized GD approach with fixed step size η = 0.1. In Figure 7(a) we calculate the softmax probability via (1/n) ∑i∈[n] maxt∈[T] S(Xi W̃(k) zi)t for either W̃ = W or KQ⊤ at each iteration. Both scenarios result in probability 1, which indicates that the attention weights succeed in selecting one token per input. Then, following Definition 2, let α = (αi)ni=1 be the token indices selected by GD and denote W⋆,αmm as the corresponding SVM solution of (Att-SVM⋆). Define the correlation coefficient of two matrices as corr_coef(W1, W2) := ⟨W1, W2⟩/(∥W1∥F ∥W2∥F). Figures 7(b) and 7(c) illustrate the correlation coefficients of the attention weights (W(k) and K(k)Q(k)⊤) with respect to Wαmm and W⋆,αmm. The results demonstrate that W (KQ⊤) ultimately reaches a correlation of 1 with Wαmm (W⋆,αmm), which suggests that W (KQ⊤) converges in the direction of Wαmm (W⋆,αmm). This further validates Theorem 5.

5.2 Overparameterization conjecture: When local-optimal directions disappear


In Section 4 we demonstrated that larger d serves as a catalyst for global convergence to select the optimal indices
opt = (opti )ni=1 . However, Section 5.1 shows that the convergence can be towards locally-optimal directions rather
than global ones. How do we reconcile these? Under what precise conditions, can we expect global convergence?
The aim of this section is gathering these intuitions and stating a concrete conjecture on the global convergence of
the attention layer under geometric assumptions related to overparameterization. To recap, Theorem 1 characterizes
when (Att-SVM) is feasible and Theorem 3 characterizes when the parameter norm provably diverges to infinity, i.e.
whenever all tokens are support vectors of (Att-SVM) (Assumption B.1 holds). On the other hand, this is not sufficient
for global convergence, as GD can converge in direction to locally-optimal directions per Section 5.1. Thus, to guarantee
global convergence, we need to ensure that globally-optimal direction is the only viable one. Our next assumption is
a fully-geometric condition that precisely accomplishes this.

Assumption D (There is always an optimal support index) For any choice of α = (αi)ni=1 with α ≠ opt when solving (Att-SVM) with opt ← α, there exists i ∈ [n] such that αi ≠ opti and opti ∈ Ti.

This guarantees that no α ≠ opt can be locally-optimal, because it has a support index with a higher score at the input i with αi ≠ opti. Thus, this ensures that the global direction W mm is the unique locally-optimal direction obeying Def. 2. Finally, note that local-optimality in Def. 2 is one-sided: GD can provably converge to locally-optimal directions, while we do not provably exclude the existence of other directions. Yet, Theorem 4 of [TLZO23] shows that local RPs (see

Figure 8: Performance of GD convergence and corresponding SVM margin (left: varying n; right: varying d). Upper: The SVM margins correspond to globally-optimal (red) and locally-optimal (blue) token indices, denoted as 1/∥W mm∥F and 1/∥Wαmm∥F, respectively. Lower: Percentages of global convergence (when α = opt, red) and local convergence (when α ≠ opt, blue).

Section 5.4 for details) can only converge to locally-optimal directions for almost all datasets.³ This and Figure 5 provide strong evidence that Def. 2 captures all possible convergence directions of GD and, as a consequence, that Assumption D guarantees that W mm is the only viable direction to converge.
➡ Integrating the results and global convergence conjecture. Combining Assumptions B.1 and D, we have concluded that the parameter norm diverges and W mm is the only viable direction to converge. Thus, we conclude this section with the following conjecture: Suppose opt = (opti)ni=1 are the unique optimal indices with strictly highest score per sequence and Assumptions B.1 and D hold. Then, for almost all datasets (e.g. adding small IID Gaussian noise to the input features), GD with a proper constant learning rate converges to W mm of (Att-SVM) in direction.
To gain some intuition, consider Figure 5: Here, red bars denote the frequency of global convergence whereas
green bars denote the frequency of Assumption B.1 holding over random problem instances. In short, this suggests
that Assumption B.1 occurs less frequently than global convergence, which is consistent with our conjecture. On the
other hand, verifying Assumption D is more challenging due to its combinatorial nature. A stronger condition that
implies Assumption D is when all optimal indices (opti )ni=1 are support vectors of the SVM. That is, either opti = αi
or opti ∈ Ti , ∀ i ∈ [n]. When the data follows a statistical model, this stronger condition could be verified through
probabilistic tools building on our earlier “all training points are support vectors” discussion [MNS+ 21]. More generally,
we believe a thorough statistical investigation of (Att-SVM) is a promising direction for future work.

5.3 Investigation on SVM objectives and GD convergence


Until now, we have discussed the global and local convergence performances of gradient descent (GD). Theorem 5
suggests that, without specific restrictions on tokens, when training with GD, the attention weight W converges towards
Wαmm . Here, the selected token indices α = (αi )ni=1 may not necessarily be identical to opt = (opti )ni=1 . Experiments
presented in Figures 2, 5, and 6 also support this observation. In this section, we focus on scenarios where α ≠ opt (e.g., when W mm is not feasible) and investigate the question: Towards which local direction is GD most likely to converge?
To this end, in Figure 8 we consider the SVM margin, defined as 1/∥Wαmm∥F, and investigate its connection to the convergence performance of GD. On the left, we set T = d = 5 and vary n among 1, 5, 10, 15; on the right, we fix T = n = 5 and change d to 2, 5, 10, 15. All tokens are randomly generated from the unit sphere. The SVM margins
³ To be precise, they prove this for their version of Def. 2, which is stated for an attention model f(p) = v⊤X⊤S(XW p) admitting an analogous SVM formulation.

corresponding to the selected tokens α are depicted as blue curves in the upper subfigures, while the SVM margins
corresponding to the globally–optimal token indices (α = opt) are shown as red curves. The red and blue bars in the
lower subfigures represent the percentages of global and local convergence, respectively. Combining all these findings
empirically demonstrates that when the global SVM objective yields a solution with a small margin (i.e., 1/∥W mm ∥F is
small, and 0 when global SVM is not feasible), GD tends to converge towards a local direction with a comparatively
larger margin.

5.4 Guarantees on local regularization path


In this section, we provide a localized regularization path analysis for general objective functions. As we shall see, this will also allow us to predict the local solutions of gradient descent described in Section 5.1. Let ⋄ denote a general norm objective. Given indices α = (αi)ni=1, consider the formulation

W⋄,αmm = arg min_{rank(W)≤m} ∥W∥⋄   subj. to   (xiαi − xit)⊤ W zi ≥ 1 for all t ≠ αi, i ∈ [n].   (⋄-SVM)

In this section, since ⋄ is clear from the context, we will use the shorthand Wαmm := W⋄,αmm and denote the optimal solution set of (⋄-SVM) as Wmm := Wmm⋄,α. It is important to note that if the ⋄-norm is not strongly convex, Wmm may not be a singleton. Additionally, when m = d, the rank constraint becomes vacuous, and the problem becomes convex. The following result is a slight generalization of Theorem 1 and demonstrates that choosing a large d ensures the feasibility of (⋄-SVM) uniformly over all choices of α. The proof is similar to that of Theorem 1, as provided in Appendix A.
Theorem 6 Suppose d ≥ max(T − 1, n) and m = d. Then, almost all datasets4 (Yi , Xi , zi )ni=1 – including the self-
attention setting with zi ← xi1 – obey the following: For any choice of indices α = (αi )ni=1 ⊂ [T ], (⋄-SVM) is feasible,
i.e. the attention layer can separate and select indices α.

To proceed, we define the local regularization path, which is obtained by solving the ⋄-norm-constrained problem
over an α-dependent cone denoted as coneϵ(α). This cone has a simple interpretation: it prioritizes tokens with a lower
score than α over tokens with a higher score than α. This interpretation sheds light on the convergence towards locally
optimal directions: lower-score tokens create a barrier for α and prevent optimization from moving towards higher-score
tokens.

Definition 3 (Low & High Score Tokens and Separating Cone) Given α ∈ [T], an input sequence X with label Y, h(·) : Rd → R, and scores γt = Y · h(xt) for all t ∈ [T], define the low and high score tokens as

lowα(X) := { t ∈ [T] : γt < γα },    highα(X) := { t ∈ [T] − {α} : γt ≥ γα }.

For input Xi and index αi, we use the shorthand notations lowαi and highαi. Finally, define coneϵ(α) as

coneϵ(α) := { W : rank(W) ≤ m,  min_{i∈[n]} max_{t∈lowαi} min_{τ∈highαi} (xit − xiτ)⊤ W zi ≥ ϵ∥W∥F }.   (6)
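A numerical membership test for coneϵ(α) in (6), using the equivalent per-input form "best low-score similarity minus best high-score similarity", is sketched below; this is our own illustration with a linear head h(x) = v⊤x, and the rank(W) ≤ m side condition is not checked (m = d case).

```python
import numpy as np

def in_cone_eps(W, X_list, z_list, Y, v, alpha, eps):
    # Eq. (6): for every input i, some low-score token must dominate all
    # high-score tokens in similarity x^T W z by at least eps * ||W||_F.
    threshold = eps * np.linalg.norm(W)           # Frobenius norm of W
    for Xi, zi, Yi, ai in zip(X_list, z_list, Y, alpha):
        scores = Yi * (Xi @ v)                    # gamma_{it} = Y_i * h(x_{it})
        T = Xi.shape[0]
        low = [t for t in range(T) if scores[t] < scores[ai]]
        high = [t for t in range(T) if t != ai and scores[t] >= scores[ai]]
        if not high:                              # nothing to guard against
            continue
        sims = Xi @ W @ zi                        # x_{it}^T W z_i
        if not low or max(sims[t] for t in low) - max(sims[t] for t in high) < threshold:
            return False
    return True
```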

Our next lemma relates this cone definition to locally-optimal directions of Definition 2.
Lemma 5 Suppose (⋄-SVM) is feasible. If indices α are locally-optimal, Wαmm ∈ coneϵ(α) for all sufficiently small ϵ > 0. Otherwise, Wαmm ∉ coneϵ(α) for all ϵ > 0. Additionally, suppose the optimal indices opti ∈ arg max_{t∈[T]} γit are unique and set α ← opt. Then, coneϵ(opt) is the set of all rank-≤ m matrices (i.e. the global set).

Lemma 5 can be understood as follows: Among the SVM solutions Wαmm , only those that are locally optimal
demonstrate a barrier of low-score tokens, effectively acting as a protective shield against higher-score tokens. Moreover,
in the case of globally optimal tokens (with the highest scores), the global set coneϵ (opt) can be chosen, as they
inherently do not require protective measures. The subsequent result introduces our principal theorem, which pertains to
the regularization path converging towards the locally-optimal direction over coneϵ (α) whenever α is locally optimal.
⁴ Here, “almost all datasets” means that adding i.i.d. Gaussian noise, with arbitrary nonzero variance, to the input features will almost surely result in the SVM's feasibility.

Theorem 7 (Convergence of Local Regularization Path) Suppose Assumption A holds. Fix locally-optimal token indices α = (αi)ni=1 and R0, ϵ > 0. Consider the norm-constrained variation of (6) defined as

C⋄ϵ,R0 := coneϵ(α) ∩ { W : ∥W∥⋄ ≥ R0 }.

Define the local RP as W̄R = arg min_{W ∈ C⋄ϵ,R0, ∥W∥⋄ ≤ R} L(W), where L(W) is given by (W-ERM). Let Wmm be the set of minima for (⋄-SVM) and Ξ⋄ > 0 be the associated margin, i.e. Ξ⋄ = 1/∥Wαmm∥⋄. For any sufficiently small ϵ > 0 and sufficiently large R0 = O(1/ϵ) > 0, limR→∞ dist(W̄R/(RΞ⋄), Wmm) = 0. Additionally, suppose the optimal indices opt = (opti)ni=1 are unique and set α ← opt. Then, the same convergence guarantee on the regularization path holds by setting C⋄ϵ,R0 as the set of rank-≤ m matrices.

Note that when setting m = d, the rank constraint is eliminated. Consequently, specializing this theorem to the
Frobenius norm aligns it with Theorem 5. On the other hand, by assigning ⋄ as the nuclear norm and α ← opt, the
global inductive bias of the nuclear norm is recovered, as stated in Theorem 2.
We would like to emphasize that both this theorem and Theorem 2 are specific instances of Theorem 9 found in
Appendix E. It is worth noting that, within this appendix, we establish all regularization path results for sequence-to-
sequence classification, along with a general class of monotonicity-preserving prediction heads outlined in Assumption
E. The latter significantly generalizes linear heads, highlighting the versatility of our theory. The following section
presents our discoveries concerning general nonlinear heads.

6 Toward A More General SVM Equivalence for Nonlinear Prediction Heads


So far, our theory has focused on the setting where the attention layer selects a single optimal token within each
sequence. As we have discussed, this is theoretically well-justified under linear head assumption and certain nonlinear
generalizations. On the other hand, for arbitrary nonconvex h(·) or multilayer transformer architectures, it is expected
that attention will select multiple tokens per sequence. This motivates us to ask:
Q: What is the implicit bias and the form of W(k) when the GD solution composes multiple tokens?
In this section, our goal is to derive and verify the generalized behavior of GD. Let oi = Xi⊤siW denote the composed token generated by the attention layer, where siW = S(XiW zi) are the softmax probabilities corresponding to W. Suppose the
GD trajectory converges to achieve the risk L⋆ = minW L(W), and the eventual token composition achieving L⋆ is
given by
o⋆i = Xi⊤ s⋆i ,
where s⋆i are the eventual softmax probability vectors that dictate the token composition. Since attention maps are
sparse in practice, we are interested in the scenario where s⋆i is sparse, i.e., it contains some zero entries. This can only
be accomplished by letting ∥W∥F → ∞. However, unlike the earlier sections, we wish to allow for arbitrary s⋆i rather
than a one-hot vector which selects a single token.
To proceed, we aim to understand the form of GD solution W(k) responsible for composing o⋆i via the softmax map
s⋆i as ∥W∥F → ∞. Intuitively, W(k) should be decomposed into two components via

W(k) ≈ W fin + ∥W(k)∥F · W̄ mm , (7)

where W fin is the finite component and W̄ mm is the directional component with ∥W̄ mm∥F = 1. Define the selected set Oi ⊆ [T] to be the set of indices with s⋆it ≠ 0, and the masked (i.e., suppressed) set as Ōi = [T] − Oi, where the softmax entries are zero.
In the context of earlier sections, we could also call these the optimal set and the non-optimal set, respectively.
• Finite component: The job of W fin is to assign the nonzero softmax probabilities within each s⋆i. This is accomplished by ensuring that W fin induces the probabilities of s⋆i over Oi by satisfying the softmax equations

    exp(xit⊤W fin zi) / exp(xiτ⊤W fin zi) = exp((xit − xiτ)⊤W fin zi) = s⋆it / s⋆iτ,

for t, τ ∈ Oi. Consequently, W fin should satisfy the following linear constraints

    (xit − xiτ)⊤W fin zi = log(s⋆it / s⋆iτ)  for all t, τ ∈ Oi, i ∈ [n].    (8)

Figure 9: Behavior of GD with nonlinear nonconvex prediction head and multi-token compositions. Upper: The
correlation between GD solution and three distinct baselines: (· · · ) W mm obtained from (Gen-SVM); (—) W SVMeq
obtained by calculating W fin and determining the best linear combination W fin + γW̄ mm that maximizes correlation with
the GD solution; and (- -) W 1token obtained by solving (Att-SVM) and selecting the highest probability token from the
GD solution. Lower: Scatterplot of the largest softmax probability over masked tokens (per our siτ ≤ 10−6 criteria) vs
correlation coefficient.

• Directional component: While W fin creates the composition by allocating the nonzero softmax probabilities, it does not explain the sparsity of the attention map. This is the role of W̄ mm, which is responsible for selecting the tokens in Oi and suppressing the masked ones in Ōi by assigning them zero softmax probability. To predict the directional
component, we build on the theory developed in earlier sections. Concretely, there are two constraints W̄ mm should
satisfy
1. Equal similarity over selected tokens: For all t, τ ∈ Oi , we have that (xit − xiτ )⊤W zi = 0. This way, softmax
scores assigned by W fin are not disturbed by the directional component and W fin + R · W̄ mm will still satisfy the
softmax equations (8).
2. Max-margin against masked tokens: For all t ∈ Oi , τ ∈ Ōi , enforce the margin constraint (xit − xiτ )⊤W zi ≥ 1
subject to minimum norm ∥W∥F .
Combining these yields the following convex generalized SVM formulation:

    W mm = arg min_W ∥W∥F   subj. to   (xit − xiτ)⊤W zi ≥ 1  for all t ∈ Oi, τ ∈ Ōi, i ∈ [n],
                                        (xit − xiτ)⊤W zi = 0  for all t, τ ∈ Oi, i ∈ [n],        (Gen-SVM)

and we set the normalized direction in (7) to W̄ mm = W mm/∥W mm∥F.
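As a concrete illustration, the following is a minimal sketch (not the authors' code; function and variable names are hypothetical) of solving (Gen-SVM) with the cvxpy convex-optimization library, assuming the selected sets Oi have already been read off from the converged softmax maps. The finite component W fin could be recovered separately by solving the linear system (8), e.g., in the least-squares sense.

```python
# Minimal sketch of (Gen-SVM), assuming cvxpy is available and the selected sets O_i are given.
import numpy as np
import cvxpy as cp

def gen_svm(X, z, selected):
    """X: (n, T, d) token matrices, z: (n, d) query tokens, selected: list of index lists O_i."""
    n, T, d = X.shape
    W = cp.Variable((d, d))
    cons = []
    for i in range(n):
        O = selected[i]
        Obar = [t for t in range(T) if t not in O]
        for t in O:
            for tau in Obar:                 # max-margin against masked tokens
                cons.append((X[i, t] - X[i, tau]) @ W @ z[i] >= 1)
            for tau in O:                    # equal similarity over selected tokens
                if tau > t:
                    cons.append((X[i, t] - X[i, tau]) @ W @ z[i] == 0)
    prob = cp.Problem(cp.Minimize(cp.norm(W, "fro")), cons)
    prob.solve()
    W_mm = W.value
    return W_mm, W_mm / np.linalg.norm(W_mm)   # (W^mm, normalized direction)
```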


It is important to note that (Gen-SVM) offers a substantial generalization beyond the scope of the previous sections,
where the focus was on selecting a single token from each sequence, as described in the main formulation (Att-SVM).
This broader solution class allows attention to select and compose an arbitrary subset of tokens from each sequence.
We present experiments showcasing the predictive power of the (Gen-SVM) equivalence in nonlinear scenarios. We
conducted these experiments on random instances using an MLP prediction head h(·) of the form h(x) = 1⊤ReLU(x).
We begin by detailing the preprocessing step and our setup. For the attention SVM equivalence analytical prediction,
clear definitions of the selected and masked sets are crucial. These sets include token indices with nonzero and zero
softmax outputs, respectively. In practice, however, the softmax outputs never reach exactly zero. Hence, we define
the selected set as tokens with softmax outputs exceeding 10−3 , and the masked set as tokens with softmax outputs
below 10−6 . We also excluded instances with softmax outputs falling between 10−6 and 10−3 to distinctly separate the
concepts of selected and masked sets, thereby enhancing the predictive accuracy of the attention SVM equivalence.
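The thresholding just described can be written down in a few lines; the sketch below (not the authors' code) uses the Figure 9 thresholds of 10⁻³ and 10⁻⁶.

```python
# Minimal sketch of extracting selected / masked sets from a converged softmax vector s (shape (T,)).
import numpy as np

def split_tokens(s, hi=1e-3, lo=1e-6):
    selected = np.where(s > hi)[0].tolist()
    masked = np.where(s < lo)[0].tolist()
    ambiguous = [t for t in range(len(s)) if lo <= s[t] <= hi]
    return selected, masked, ambiguous   # instances with ambiguous tokens are filtered out
```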
In addition to this filtering, we focus on instances in which the label Yi = −1 occurs, so as to enforce the non-convexity of the prediction head Yi · h(·). It is worth mentioning that when all labels are 1, due to the convexity of Yi · h(·), GD tends to select one token per input, and (Gen-SVM) and (Att-SVM) yield the same solutions. The results are displayed in Figure 9, where n = 3, T = 4, and d ∈ {4, 6, 8, 10}. We conduct 500 random trials for different
choices of d, each involving xit , zi , and v randomly sampled from the unit sphere. We apply normalized GD with a step
size η = 0.1 and run 2000 iterations for each trial.
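For reference, the following is a minimal sketch of the normalized-GD loop described above (this is not the authors' experiment code; it assumes a logistic loss, takes the head exactly as stated, h(x) = 1⊤ReLU(x), and uses a finite-difference gradient purely to keep the sketch short).

```python
import numpy as np

def softmax(a):
    e = np.exp(a - a.max())
    return e / e.sum()

def risk(W, X, z, Y):
    # empirical risk with attention layer, head h(x) = 1^T ReLU(x), logistic loss
    total = 0.0
    for i in range(len(X)):
        s = softmax(X[i] @ W @ z[i])            # softmax attention probabilities
        o = X[i].T @ s                           # composed token o_i
        h = np.maximum(o, 0.0).sum()             # h(o_i) = 1^T ReLU(o_i)
        total += np.log1p(np.exp(-Y[i] * h))     # logistic loss l(Y_i * h(o_i))
    return total / len(X)

def normalized_gd(X, z, Y, eta=0.1, iters=2000, fd_eps=1e-5):
    d = X.shape[2]
    W = np.zeros((d, d))
    for _ in range(iters):
        G = np.zeros_like(W)                     # central finite-difference gradient
        for a in range(d):
            for b in range(d):
                E = np.zeros_like(W)
                E[a, b] = fd_eps
                G[a, b] = (risk(W + E, X, z, Y) - risk(W - E, X, z, Y)) / (2 * fd_eps)
        W = W - eta * G / (np.linalg.norm(G) + 1e-12)   # normalized step
    return W
```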

• Figure 9 (upper) illustrates the correlation evolution between the GD solution and three distinctive baselines: (· · · )
W mm obtained from (Gen-SVM); (—) W SVMeq obtained by calculating W fin and determining the best linear combination
W fin + γW̄ mm that maximizes correlation with the GD solution; and (- -) W 1token obtained by solving (Att-SVM)
and selecting the highest probability token from the GD solution. For clearer visualization, the logarithmic scale of
correlation misalignment is presented in Figure 9. In essence, our findings show that W 1token yields unsatisfactory
outcomes, whereas W mm attains a significant correlation coefficient in alignment with our expectations. Ultimately,
our comprehensive SVM-equivalence W SVMeq further enhances correlation, lending support to our analytical formulas.
It is noteworthy that the SVM-equivalence displays higher predictivity in the larger-d regime (with an average correlation
exceeding 0.99). This phenomenon might be attributed to more frequent directional convergence in higher dimensions,
with overparameterization contributing to a smoother loss landscape, thereby expediting optimization.

• Figure 9 (lower) offers a scatterplot overview of the 500 random problem instances that were solved. The x-axis
represents the largest softmax probability over the masked set, denoted as maxi,τ siτ where τ ∈ Ōi . Meanwhile, the
y-axis indicates the predictivity of the SVM-equivalence, quantified as 1 − corr_coef(W, W SVMeq ). From this analysis,
two significant observations arise. Primarily, there exists an inverse correlation between softmax probability and
SVM-predictivity. This correlation is intuitive, as higher softmax probabilities signify a stronger divergence from our
desired masked set state (ideally set to 0). Secondly, as dimensionality (d) increases, softmax probabilities over the
masked set tend to converge towards the range of 10−15 (effectively zero). Simultaneously, attention SVM-predictivity
improves, creating a noteworthy correlation.

6.1 When does attention select multiple tokens?


In this section, we provide a concrete example where the optimal solution indeed requires combining multiple tokens in a nontrivial fashion. By nontrivial we mean that more than one token is selected from an input sequence, but not all of its tokens. Recall that, for a linear prediction head, attention ideally selects the single token with the largest score for almost all datasets. Perhaps not surprisingly, this behavior does not persist for nonlinear prediction heads. For instance, in Figure 9, the GD output W aligned better in direction with W mm than with W 1token. Specifically, here we prove that if we make the function hY(x) := Y · h(x) concave, then the optimal softmax map can select multiple tokens in a controllable fashion. hY(x) can be viewed as a generalization of the linear score function Y · v⊤x. In the
example below, we induce concavity by incorporating a small −λ∥x∥2 term within a linear prediction head and setting
h(x) = v⊤ x − λ∥x∥2 with Y = 1.

Lemma 6 Given v ∈ Rd, recall the score vector γ = Xv. Without loss of generality, assume γ is non-increasing. Define the vector of score gaps γgap ∈ RT−1 with entries γtgap = γt − γt+1. Suppose all tokens within the input sequence are orthonormal and, for some τ ≥ 2, we have that

    τ γτgap / 2 > γ1gap.    (9)

Set h(x) = v⊤x − λ∥x∥² where τ γτgap/2 > λ > γ1gap, ℓ(x) = −x, and Y = 1. Let ∆T denote the T-dimensional simplex. Define the unconstrained softmax optimization associated to the objective h, where we treat s := S(XW z) as a free variable, namely,

    min_{s∈∆T} ℓ(h(X⊤s)) = min_{s∈∆T} λ∥X⊤s∥² − v⊤X⊤s.    (10)

Then, the optimal solution s⋆ contains at least 2 and at most τ nonzero entries.
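Since (10) is a convex quadratic program over the simplex, the sparsity claim can be probed directly. Below is a minimal sketch (not from the paper) using cvxpy; the orthonormal tokens and the vector v are random stand-ins.

```python
import numpy as np
import cvxpy as cp

def optimal_composition(X, v, lam):
    # solve min_{s in simplex} lam * ||X^T s||^2 - v^T X^T s
    T = X.shape[0]
    s = cp.Variable(T, nonneg=True)
    obj = lam * cp.sum_squares(X.T @ s) - v @ (X.T @ s)
    cp.Problem(cp.Minimize(obj), [cp.sum(s) == 1]).solve()
    return s.value

T, d = 5, 8
X = np.linalg.qr(np.random.randn(d, T))[0].T    # rows are orthonormal tokens
v = np.random.randn(d)
s_star = optimal_composition(X, v, lam=0.5)
print((s_star > 1e-6).sum())                    # number of selected tokens
```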

Figure 10: Behavior of GD when selecting multiple tokens. (a) The number of selected tokens increases with λ. (b)
Predictivity of attention SVM solutions for varying λ; Dotted curves depict the correlation corresponding to W mm
calculated via (Gen-SVM) and solid curves represent the correlation to W SVMeq , which incorporates the W fin correction.
(c) Similar to (b), but evaluating correlations over different numbers of selected tokens.

Figure 10 presents experimental findings concerning Lemma 6 across random problem instances. For this experiment,
we set n = 1, T = 10, and d = 10. The results are averaged over 100 random trials, with each trial involving the
generation of random orthonormal vectors x1t and the random sampling of the vector v from the unit sphere. Similar
to the processing step in Figure 9, and following Figure 9 (lower) which illustrates that smaller softmax outputs over
masked sets correspond to higher correlation coefficients, we define the selected and masked token sets. Specifically,
tokens with softmax outputs > 10−3 are considered selected, while tokens with softmax outputs < 10−8 are masked.
Instances with softmax outputs between 10−8 and 10−3 are filtered out.
Figure 10(a) shows that the number of selected tokens grows alongside λ, a prediction consistent with Lemma
6. When λ = 0, the head h(x) = v⊤ x is linear, resulting in the selection of only one token per input. Conversely, as
λ exceeds a certain threshold (e.g., λ > 2.0 based on our criteria), the optimization consistently selects all tokens.
Figure 10(b) and 10(c) delve into the predictivity of attention SVM solutions for varying λ and different numbers of
selected tokens. The dotted curves in both figures represent 1 − corr_coef(W, W mm ), while solid curves indicate
1 − corr_coef(W, W SVMeq ), where W denotes the GD solution. Overall, the SVM-equivalence demonstrates a strong
correlation with the GD solution (consistently above 0.95). However, selecting more tokens (aligned with larger λ
values) leads to reduced predictivity.
To sum up, we have showcased the predictive capacity of the generalized SVM equivalence regarding the inductive
bias of 1-layer transformers with nonlinear heads. Nevertheless, it’s important to acknowledge that this section
represents an initial approach to a complex problem, with certain caveats requiring further investigation (e.g., the use of
filtering in Figures 9 and 10, and the presence of imperfect correlations). We aspire to conduct a more comprehensive
investigation, both theoretically and empirically, in forthcoming work.

7 Extending the Theory to Sequential and Causal Predictions


While our formulations (W-ERM & KQ-ERM) regress a single label Yi per (Xi , zi ), we extend in Appendix E our findings
to the sequence-to-sequence classification setting, where we output and classify all T tokens per input Xi , Z i ∈ RT ×d . In
this scenario, we prove that all of our RP guarantees remain intact after introducing a slight generalization of (Att-SVM).
Concretely, consider the following ERM problems for sequential and causal settings:
    Lseq(W) = (1/n) Σ_{i=1}^n Σ_{k=1}^T ℓ(Yik · h(Xi⊤ S(XiW zik))),   and   Lcsl(W) = (1/n) Σ_{i=1}^n Σ_{k=1}^T ℓ(Yik · h(Xi⊤ Sk(XiW zik))).    (11)

Both equations train T tokens per input (Xi , Z i ) and, as usual, we recover self-attention via Z i ← Xi . For the causal
setting, we use the masked attention Sk (·) which calculates the softmax probabilities over the first k entries of its input
and sets the remaining T − k entries to zero. This way, the k’th prediction of the transformer only utilizes tokens from 1
to k and not the future tokens.
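As an illustration, here is a minimal sketch (not the authors' code) of the masked softmax Sk(·) just described.

```python
import numpy as np

def masked_softmax(a, k):
    # softmax over the first k entries of a; the remaining T - k entries are set to zero
    out = np.zeros_like(a, dtype=float)
    e = np.exp(a[:k] - np.max(a[:k]))
    out[:k] = e / e.sum()
    return out

# e.g. masked_softmax(np.array([1.0, 2.0, 0.5, 3.0]), k=2) -> [0.269, 0.731, 0.0, 0.0]
```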

Let α = (αik)i∈[n], k∈[T] be the tokens to be selected by attention (e.g., locally-optimal indices, see Def. 4). Then, the sequential generalization of (Att-SVM) corresponding to Lseq(W) is given by

    min_W ∥W∥F   subj. to   (xiαik − xit)⊤W zik ≥ 1   for all t ≠ αik, k ∈ [T], i ∈ [n].    (12)

We refer the reader to Appendix E which rigorously establishes all RP results for this sequential classification setting.
On the other hand, for the causal inference setting, the SVM should reflect the fact that the model is not allowed to make use of future tokens. Note that the SVM constraints directly arise from the softmax calculations. Since attention is masked over the indices t ≤ k for each k ∈ [T], the SVM constraints should apply over the same mask. We can therefore consider the straightforward generalization of global-optimality, where optik is the token index with the highest score over indices t ≤ k, and introduce an analogous definition of local-optimality. This leads to the following variation of (12), which aims to select indices αik ∈ [k] from the first k tokens:

    min_W ∥W∥F   subj. to   (xiαik − xit)⊤W zik ≥ 1   for all t ≠ αik, t ≤ k, k ∈ [T], i ∈ [n].

Causal attention is a special case of a general attention mask, which can restrict the softmax to an arbitrary subset of the tokens. Such general masking can be handled similarly to (7) by enforcing the SVM constraints over the nonzero support of the mask. Finally, the discussion so far extends our main theoretical results and focuses on selecting a single token per sequence. It can further be enriched by the generalized SVM equivalence developed in Section 6 to select and compose
multiple tokens by generalizing (Gen-SVM).
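To make the causal variant concrete, here is a minimal sketch (hypothetical, not from the paper) that enforces the masked SVM constraints with cvxpy, using 0-based indexing so that prediction k may only attend to tokens 0, …, k.

```python
import numpy as np
import cvxpy as cp

def causal_att_svm(X, Z, alpha):
    """X: (n, T, d) tokens, Z: (n, T, d) queries z_{ik}, alpha[i][k] in {0, ..., k}."""
    n, T, d = X.shape
    W = cp.Variable((d, d))
    cons = []
    for i in range(n):
        for k in range(T):
            a = alpha[i][k]
            for t in range(k + 1):           # constraints only over the unmasked prefix
                if t != a:
                    cons.append((X[i, a] - X[i, t]) @ W @ Z[i, k] >= 1)
    cp.Problem(cp.Minimize(cp.norm(W, "fro")), cons).solve()
    return W.value
```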

8 Related work
8.1 Implicit regularization, matrix factorization, and sparsity
Extensive research has delved into gradient descent’s implicit bias in separable classification tasks, often using logistic or
exponentially-tailed losses for margin maximization [SHN+ 18, GLSS18, NLG+ 19, JT21, KPOT21, MWG+ 20, JT20].
The findings have also been extended to non-separable data using gradient-based techniques [JT18, JT19, JDST20].
Implicit bias in regression problems and losses has been investigated, utilizing methods like mirror descent [WGL+ 20,
GLSS18, YKM20, VKR19, AW20a, AW20b, ALH21, SATA22]. Stochastic gradient descent has also been a subject of
interest regarding its implicit bias [LWM19, BGVV20, LR20, HWLM20, LWA22, DML21, ZWB+ 21]. This extends
to the implicit bias of adaptive and momentum-based methods [QQ19, WMZ+ 21, WMCL21, JST21].
In linear classification, GD iterations on logistic loss and separable datasets converge to the hard margin SVM
solution [SHN+ 18, RZH03, ZY05]. The attention layer’s softmax nonlinearity behaves similarly, potentially favoring
margin-maximizing solutions. Yet, the layer operates on tokens within input sequences rather than performing classification directly. Its bias leans toward (Att-SVM), selecting relevant tokens while suppressing others. However, formalizing this intuition
presents significant challenges: Firstly, our problem is nonconvex (even in terms of the W-parameterization), introducing
new challenges and complexities. Secondly, it requires the introduction of novel concepts such as locally-optimal
tokens, demanding a tailored analysis focused on the cones surrounding them. Our findings on the implicit bias
of (K, Q)-parameterization share conceptual similarities with [SRJ04], which proposes and analyzes a max-margin
matrix factorization problem. Similar problems have also been studied more recently in the context of neural-collapse
phenomena [PHD20] through an analysis of the implicit bias and regularization path of the unconstrained features
model with cross-entropy loss [TKVB22]. However, a fundamental distinction from these works lies in the fact that
attention solves a different max-margin problem that separates tokens. Moreover, our results on (K, Q)-parameterization
are inherently connected to the rich literature on low-rank factorization [GWB+ 17, ACHL19, TVS23, TBS+ 16, SS21],
stimulating further research. [TLZO23] is the first work to establish the connection between attention and SVM,
which is closest to our work. Here, we augment their framework, initially developed for a simpler attention model, to
transformers by providing the first guarantees for self/cross-attention layers, nonlinear prediction heads, and realistic
global convergence guarantees. While our Assumption B.2 and local-convergence analysis align with [TLZO23], our
contributions in global convergence analysis, benefits of overparameterization, and the generalized SVM-equivalence in
Section 6 are unique to this work.
It is well-known that the attention map (i.e., softmax outputs) acts as a feature-selection mechanism and reveals the tokens that are relevant to classification. On the other hand, sparsity and lasso regression (i.e., ℓ1 penalization)
[Don06, Tib96, TG07, CDS01, CRT06] have been pivotal tools in the statistics literature for feature selection. Softmax and lasso regression exhibit interesting parallels: the softmax output s = S(XW z) obeys ∥s∥ℓ1 = 1 by design. Softmax is also highly amenable to sparsity because decreasing the temperature (i.e., scaling up the weights W) eventually leads to a one-hot vector unless all logits are equal. We (and also [TLZO23]) have used these intuitions to formalize
attention as a token selection mechanism. This aspect is clearly visible in our primary SVM formulation (Att-SVM)
which selects precisely one token from each input sequence (i.e. hard attention). Section 6 has also demonstrated how
(Gen-SVM) can explain more general sparsity patterns by precisely selecting desired tokens and suppressing others. We
hope that this SVM-based token-selection viewpoint will motivate future work and deeper connections to the broader
feature-selection and compressed sensing literature.

8.2 Attention mechanism and transformers


Transformers, as highlighted by [VSP+ 17], revolutionized the domains of NLP and machine translation. Prior work on
self-attention [CDL16, PTDU16, PXS18, LFS+ 17] laid the foundation for this transformative paradigm. In contrast
to conventional models like MLPs and CNNs, self-attention models employ global interactions to capture feature
representations, resulting in exceptional empirical performance.
Despite their achievements, the mechanisms and learning processes of attention layers remain enigmatic. Recent
investigations [EGKZ22, SEO+ 22, ENM22, BV22, DCL21] have concentrated on specific aspects such as sparse
function representation, convex relaxations, and expressive power. Expressivity discussions concerning hard-attention
[Hah20] or attention-only architectures [DCL21] are connected to our findings when h(·) is linear. In fact, our work reveals how a linear h causes attention's optimization dynamics to collapse onto a single token, whereas a nonlinear h provably requires attention to select and compose multiple tokens. This supports the benefits of the MLP layer for
expressivity of transformers. There is also a growing body of research aimed at a theoretical comprehension of in-context
learning and the role played by the attention mechanism [ASA+ 22, LIPO23, ACDS23, ZFB23, BCW+ 23, GRS+ 23].
[SEO+ 22] investigate self-attention with linear activation instead of softmax, while [ENM22] approximate softmax
using a linear operation with unit simplex constraints. Their primary goal is to derive convex reformulations for
training problems grounded in empirical risk minimization (ERM). In contrast, our methodologies, detailed in equations
(W-ERM) and (KQ-ERM), delve into the nonconvex domain.
[MRG+ 20, BALA+ 23] offer insights into the implicit bias of optimizing transformers. Specifically, [MRG+ 20]
provide empirical evidence that an increase in attention weights results in a sparser softmax, which aligns with our
theoretical framework. [BALA+ 23] study incremental learning and furnish both theory and numerical evidence that
increments of the softmax attention weights (KQ⊤ ) are low-rank. Our theory aligns with this concept, as the SVM
formulation (KQ-SVM) of (K, Q) parameterization inherently exhibits low-rank properties through the nuclear norm
objective, rank-m constraint, and implicit constraint induced by Lemma 1.
Several recent works [JSL22, LWLC23, TWCD23, NLL+ 23, ORST23, NNH+ 23, FGBM23] aim to delineate
the optimization and generalization dynamics of transformers. However, their findings usually apply under strict
statistical assumptions about the data, while our study offers a comprehensive optimization-theoretic analysis of the
attention model, establishing a formal linkage to max-margin problems and SVM geometry. This allows our findings
to encompass the problem geometry and apply to diverse datasets. Overall, the max-margin equivalence provides
a fundamental comprehension of the optimization geometry of transformers, offering a framework for prospective
research endeavors, as outlined in the subsequent section.

9 Discussion, Future Directions, and Open Problems


Our optimization-theoretic characterization of the self-attention model provides a comprehensive understanding of
its underlying principles. The developed framework, along with the research presented in [TLZO23], introduces new
avenues for studying transformers and language models. The key findings include:
✓ The optimization geometry of self-attention exhibits a fascinating connection to hard-margin SVM problems. By
leveraging linear constraints formed through outer products of token pairs, optimal input tokens can be effectively
separated from non-optimal ones.
✓ When gradient descent is employed without early-stopping, implicit regularization and convergence of self-attention
naturally occur. This convergence leads to the maximum-margin solution when minimizing the empirical risk with the logistic loss, exp-loss, or other smooth, strictly decreasing loss functions. Moreover, this implicit bias is unaffected by the step
size, as long as it is sufficiently small for convergence, and remains independent of the initialization process.

The fact that gradient descent leads to a maximum margin solution may not be surprising to those who are familiar with
the relationship between regularization path and gradient descent in linear and nonlinear neural networks [SHN+ 18,
GLSS18, NLG+ 19, JT21, MWG+ 20, JT20]. However, there is a lack of prior research or discussion regarding this
connection to the attention mechanism. Moreover, there has been no rigorous analysis or investigation into the exactness
and independence of this bias with respect to the initialization and step size. Thus, we believe our findings and insights
deepen our understanding of transformers and language models, paving the way for further research in this domain.
Below, we discuss some notable directions and highlight open problems that are not resolved by the existing theory.
• Convergence Rates: The current paper establishes asymptotic convergence of gradient descent; nonetheless, there
is room for further exploration to characterize non-asymptotic convergence rates. Indeed, such an exploration can
also provide valuable insights into the choice of learning rate, initialization, and the optimization method.

• Gradient descent on (K, Q) parameterization: We find it remarkable that regularization path analysis was able
to predict the implicit bias of gradient descent. Complete analysis of gradient descent is inherently connected to
the fundamental question of low-rank factorization [GWB+ 17, LMZ18]. We believe formalizing the implicit bias
of gradient descent under margin constraints presents an exciting open direction for further research.
• Generalization analysis: An important direction is the generalization guarantees for gradient-based algorithms.
The established connection to hard-margin SVM can facilitate this because the SVM problem is amenable to
statistical analysis. This would be akin to how kernel/NTK analysis for deep nets enabled a rich literature on
generalization analysis for traditional deep learning.
• Global convergence of gradient descent: We lack a complete characterization of the directional convergence
of gradient descent. We ask: Where does gradient descent directionally-converge from arbitrary initialization
for 1-layer self-attention? The role of over-parameterization as conjectured in Section 5.2 and the notion of
locally-optimal directions discussed in Section 5 constitute important pieces of this puzzle (also see the discussion
in [TLZO23]).
• Realistic architectures: Naturally, we wish to explore whether max-margin equivalence can be extended to more
realistic settings: Can the theory be expanded to handle multi-head attention, multi-layer architectures, and MLP
nonlinearities? We believe the results in Section 6 take an important step towards this direction by including
analytical formulae for the implicit bias of the attention layer under nonlinear prediction heads.
• Jointly optimizing attention and prediction head: It would be interesting to study the joint optimization
dynamics of attention weights and the prediction head h(·). This problem can be viewed as a novel low-rank factorization-type problem in which h(·) and W are the factors, except that here W passes through
the softmax nonlinearity. To this aim, [TLZO23] provides a preliminary geometric characterization of the implicit
bias for a simpler attention model using regularization path analysis. Such findings can potentially be generalized
to the analysis of gradient methods and full transformer block.

Acknowledgements
This work was supported by the NSF grants CCF-2046816 and CCF-2212426, Google Research Scholar award,
and Army Research Office grant W911NF2110312. The authors thank Xuechen Zhang, Ankit Singh Rawat, Mahdi
Soltanolkotabi, Jason Lee, Arkadas Ozakin, Ramya Korlakai Vinayak, and Babak Hassibi for helpful suggestions and
discussion.

References
[ACDS23] Kwangjun Ahn, Xiang Cheng, Hadi Daneshmand, and Suvrit Sra. Transformers learn to implement
preconditioned gradient descent for in-context learning. arXiv preprint arXiv:2306.00297, 2023.
[ACHL19] Sanjeev Arora, Nadav Cohen, Wei Hu, and Yuping Luo. Implicit regularization in deep matrix factorization.
Advances in Neural Information Processing Systems, 32, 2019.

[ALH21] Navid Azizan, Sahin Lale, and Babak Hassibi. Stochastic mirror descent on overparameterized nonlinear
models. IEEE Transactions on Neural Networks and Learning Systems, 33(12):7717–7727, 2021.
[ASA+ 22] Ekin Akyürek, Dale Schuurmans, Jacob Andreas, Tengyu Ma, and Denny Zhou. What learning algorithm
is in-context learning? investigations with linear models. arXiv:2211.15661, 2022.
[AW20a] Ehsan Amid and Manfred K Warmuth. Winnowing with gradient descent. In Conference on Learning
Theory, pages 163–182. PMLR, 2020.
[AW20b] Ehsan Amid and Manfred KK Warmuth. Reparameterizing mirror descent as gradient descent. Advances
in Neural Information Processing Systems, 33:8430–8439, 2020.
[AZLS19] Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via over-
parameterization. In International Conference on Machine Learning, pages 242–252. PMLR, 2019.
[BALA+ 23] Enric Boix-Adsera, Etai Littwin, Emmanuel Abbe, Samy Bengio, and Joshua Susskind. Transformers
learn through gradual rank increase. arXiv preprint arXiv:2306.07042, 2023.
[BCW+ 23] Yu Bai, Fan Chen, Huan Wang, Caiming Xiong, and Song Mei. Transformers as statisticians: Provable
in-context learning with in-context algorithm selection. arXiv preprint arXiv:2306.04637, 2023.
[BGVV20] Guy Blanc, Neha Gupta, Gregory Valiant, and Paul Valiant. Implicit regularization for deep neural
networks driven by an ornstein-uhlenbeck like process. In Conference on learning theory, pages 483–513.
PMLR, 2020.
[BLLT20] Peter L Bartlett, Philip M Long, Gábor Lugosi, and Alexander Tsigler. Benign overfitting in linear
regression. Proceedings of the National Academy of Sciences, 117(48):30063–30070, 2020.
[BMR+ 20] Tom Brown, Benjamin Mann, Nick Ryder, Melanie Subbiah, Jared D Kaplan, Prafulla Dhariwal, Arvind
Neelakantan, Pranav Shyam, Girish Sastry, Amanda Askell, and et al. Language models are few-shot
learners. In Advances in neural information processing systems, volume 33, pages 1877–1901, 2020.
[BV22] Pierre Baldi and Roman Vershynin. The quarks of attention. arXiv preprint arXiv:2202.08371, 2022.
[Car21] Marcus Carlsson. Von Neumann’s trace inequality for Hilbert–Schmidt operators. Expositiones Mathemati-
cae, 39(1):149–157, 2021.
[CDL16] Jianpeng Cheng, Li Dong, and Mirella Lapata. Long short-term memory-networks for machine reading.
In Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing, pages
551–561, Austin, Texas, November 2016. Association for Computational Linguistics.
[CDS01] Scott Shaobing Chen, David L Donoho, and Michael A Saunders. Atomic decomposition by basis pursuit.
SIAM review, 43(1):129–159, 2001.
[CLR+ 21] Lili Chen, Kevin Lu, Aravind Rajeswaran, Kimin Lee, Aditya Grover, Misha Laskin, Pieter Abbeel,
Aravind Srinivas, and Igor Mordatch. Decision transformer: Reinforcement learning via sequence
modeling. In Advances in Neural Information Processing Systems, volume 34, pages 15084–15097, 2021.
[CRT06] Emmanuel J Candès, Justin Romberg, and Terence Tao. Robust uncertainty principles: Exact signal
reconstruction from highly incomplete frequency information. IEEE Transactions on information theory,
52(2):489–509, 2006.
[CSL+ 23] Yingyi Chen, Xi Shen, Yahui Liu, Qinghua Tao, and Johan AK Suykens. Jigsaw-vit: Learning jigsaw
puzzles in vision transformer. Pattern Recognition Letters, 166:53–60, 2023.
[DCL21] Yihe Dong, Jean-Baptiste Cordonnier, and Andreas Loukas. Attention is not all you need: Pure attention
loses rank doubly exponentially with depth. In International Conference on Machine Learning, pages
2793–2803. PMLR, 2021.
[DML21] Alex Damian, Tengyu Ma, and Jason Lee. Label noise sgd provably prefers flat global minimizers. arXiv
preprint arXiv:2106.06530, 2021.

[Don06] David L Donoho. Compressed sensing. IEEE Transactions on information theory, 52(4):1289–1306,
2006.
[DZPS18] Simon S Du, Xiyu Zhai, Barnabas Poczos, and Aarti Singh. Gradient descent provably optimizes
over-parameterized neural networks. arXiv preprint arXiv:1810.02054, 2018.
[EGKZ22] Benjamin L Edelman, Surbhi Goel, Sham Kakade, and Cyril Zhang. Inductive biases and variable creation
in self-attention mechanisms. In International Conference on Machine Learning, pages 5793–5831.
PMLR, 2022.
[ENM22] Tolga Ergen, Behnam Neyshabur, and Harsh Mehta. Convexifying transformers: Improving optimization
and understanding of transformer networks. arXiv:2211.11052, 2022.
[Faz02] Maryam Fazel. Matrix rank minimization with applications. PhD thesis, Stanford University,
2002.
[FGBM23] Hengyu Fu, Tianyu Guo, Yu Bai, and Song Mei. What can a single attention layer learn? a study through
the random features lens. arXiv preprint arXiv:2307.11353, 2023.
[FXM+ 21] Haoqi Fan, Bo Xiong, Karttikeya Mangalam, Yanghao Li, Zhicheng Yan, Jitendra Malik, and Christoph
Feichtenhofer. Multiscale vision transformers. In Proceedings of the IEEE/CVF International Conference
on Computer Vision, pages 6824–6835, 2021.
[GLSS18] Suriya Gunasekar, Jason Lee, Daniel Soudry, and Nathan Srebro. Characterizing implicit bias in terms of
optimization geometry. In International Conference on Machine Learning, pages 1832–1841. PMLR,
2018.
[Goo23] Google. Try bard, an ai experiment by google. https://bard.google.com, 2023.
[GRS+ 23] Angeliki Giannou, Shashank Rajput, Jy-yong Sohn, Kangwook Lee, Jason D Lee, and Dimitris Papail-
iopoulos. Looped transformers as programmable computers. arXiv:2301.13196, 2023.
[GWB+ 17] Suriya Gunasekar, Blake E Woodworth, Srinadh Bhojanapalli, Behnam Neyshabur, and Nati Srebro.
Implicit regularization in matrix factorization. Advances in neural information processing systems, 30,
2017.
[Hah20] Michael Hahn. Theoretical limitations of self-attention in neural sequence models. Transactions of the
Association for Computational Linguistics, 8:156–171, 2020.
[HMX21] Daniel Hsu, Vidya Muthukumar, and Ji Xu. On the proliferation of support vectors in high dimensions. In
International Conference on Artificial Intelligence and Statistics, pages 91–99. PMLR, 2021.
[HWLM20] Jeff Z HaoChen, Colin Wei, Jason D Lee, and Tengyu Ma. Shape matters: Understanding the implicit bias
of the noise covariance. arXiv preprint arXiv:2006.08680, 2020.
[JDST20] Ziwei Ji, Miroslav Dudík, Robert E Schapire, and Matus Telgarsky. Gradient descent follows the
regularization path for general losses. In Conference on Learning Theory, pages 2109–2136. PMLR,
2020.
[JGH18] Arthur Jacot, Franck Gabriel, and Clément Hongler. Neural tangent kernel: Convergence and generaliza-
tion in neural networks. arXiv preprint arXiv:1806.07572, 2018.
[JLL21] Michael Janner, Qiyang Li, and Sergey Levine. Offline reinforcement learning as one big sequence
modeling problem. Advances in neural information processing systems, 34:1273–1286, 2021.
[JSL22] Samy Jelassi, Michael Eli Sander, and Yuanzhi Li. Vision transformers provably learn spatial structure.
In Alice H. Oh, Alekh Agarwal, Danielle Belgrave, and Kyunghyun Cho, editors, Advances in Neural
Information Processing Systems, 2022.
[JST21] Ziwei Ji, Nathan Srebro, and Matus Telgarsky. Fast margin maximization via dual acceleration. In
International Conference on Machine Learning, pages 4860–4869. PMLR, 2021.

[JT18] Ziwei Ji and Matus Telgarsky. Risk and parameter convergence of logistic regression. arXiv preprint
arXiv:1803.07300, 2018.
[JT19] Ziwei Ji and Matus Telgarsky. The implicit bias of gradient descent on nonseparable data. In Conference
on Learning Theory, pages 1772–1798. PMLR, 2019.
[JT20] Ziwei Ji and Matus Telgarsky. Directional convergence and alignment in deep learning. In H. Larochelle,
M. Ranzato, R. Hadsell, M. F. Balcan, and H. Lin, editors, Advances in Neural Information Processing
Systems, volume 33, pages 17176–17186. Curran Associates, Inc., 2020.
[JT21] Ziwei Ji and Matus Telgarsky. Characterizing the implicit bias via a primal-dual analysis. In Algorithmic
Learning Theory, pages 772–804. PMLR, 2021.
[KPOT21] Ganesh Ramachandra Kini, Orestis Paraskevas, Samet Oymak, and Christos Thrampoulidis. Label-
imbalanced and group-sensitive classification under overparameterization. Advances in Neural Information
Processing Systems, 34:18970–18983, 2021.
[KT19] Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: Pre-training of deep bidirectional
transformers for language understanding. In Proceedings of NAACL-HLT, pages 4171–4186, 2019.
[LFS+ 17] Zhouhan Lin, Minwei Feng, Cicero Nogueira dos Santos, Mo Yu, Bing Xiang, Bowen Zhou, and Yoshua
Bengio. A structured self-attentive sentence embedding. In International Conference on Learning
Representations, 2017.
[LIPO23] Yingcong Li, M Emrullah Ildiz, Dimitris Papailiopoulos, and Samet Oymak. Transformers as algorithms:
Generalization and stability in in-context learning. In International Conference on Machine Learning,
2023.
[LL18] Yuanzhi Li and Yingyu Liang. Learning overparameterized neural networks via stochastic gradient descent
on structured data. Advances in neural information processing systems, 31, 2018.
[LL19] Kaifeng Lyu and Jian Li. Gradient descent maximizes the margin of homogeneous neural networks. arXiv
preprint arXiv:1906.05890, 2019.
[LLC+ 21] Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, and Baining Guo. Swin
transformer: Hierarchical vision transformer using shifted windows. In Proceedings of the IEEE/CVF
International Conference on Computer Vision, pages 10012–10022, 2021.
[LMZ18] Yuanzhi Li, Tengyu Ma, and Hongyang Zhang. Algorithmic regularization in over-parameterized matrix
sensing and neural networks with quadratic activations. In Conference On Learning Theory, pages 2–47.
PMLR, 2018.
[LR20] Tengyuan Liang and Alexander Rakhlin. Just interpolate: Kernel “ridgeless” regression can
generalize. The Annals of Statistics, 48(3):1329–1347, 2020.
[LWA22] Zhiyuan Li, Tianhao Wang, and Sanjeev Arora. What happens after SGD reaches zero loss? –a mathemat-
ical framework. In International Conference on Learning Representations, 2022.
[LWLC23] Hongkang Li, Meng Wang, Sijia Liu, and Pin-Yu Chen. A theoretical understanding of shallow vision
transformers: Learning, generalization, and sample complexity. arXiv preprint arXiv:2302.06015, 2023.
[LWM19] Yuanzhi Li, Colin Wei, and Tengyu Ma. Towards explaining the regularization effect of initial large
learning rate in training neural networks. arXiv preprint arXiv:1907.04595, 2019.
[MNS+ 21] Vidya Muthukumar, Adhyyan Narang, Vignesh Subramanian, Mikhail Belkin, Daniel Hsu, and Anant
Sahai. Classification vs regression in overparameterized regimes: Does the loss function matter? The
Journal of Machine Learning Research, 22(1):10104–10172, 2021.
[MRG+ 20] William Merrill, Vivek Ramanujan, Yoav Goldberg, Roy Schwartz, and Noah Smith. Effects of pa-
rameter norm growth during transformer training: Inductive bias from gradient descent. arXiv preprint
arXiv:2010.09697, 2020.

[MWG+ 20] Edward Moroshko, Blake E Woodworth, Suriya Gunasekar, Jason D Lee, Nati Srebro, and Daniel Soudry.
Implicit bias in deep linear classification: Initialization scale vs training accuracy. Advances in neural
information processing systems, 33:22182–22193, 2020.
[NLG+ 19] Mor Shpigel Nacson, Jason Lee, Suriya Gunasekar, Pedro Henrique Pamplona Savarese, Nathan Srebro,
and Daniel Soudry. Convergence of gradient descent on separable data. In The 22nd International
Conference on Artificial Intelligence and Statistics, pages 3420–3428. PMLR, 2019.
[NLL+ 23] Lorenzo Noci, Chuning Li, Mufan Bill Li, Bobby He, Thomas Hofmann, Chris Maddison, and Daniel M
Roy. The shaped transformer: Attention models in the infinite depth-and-width limit. arXiv preprint
arXiv:2306.17759, 2023.
[NNH+ 23] Tan Minh Nguyen, Tam Minh Nguyen, Nhat Ho, Andrea L Bertozzi, Richard Baraniuk, and Stanley Osher.
A primal-dual framework for transformers and neural networks. In The Eleventh International Conference
on Learning Representations, 2023.
[OH10] Samet Oymak and Babak Hassibi. New null space results and recovery thresholds for matrix rank
minimization. arXiv preprint arXiv:1011.6326, 2010.

[OMFH11] Samet Oymak, Karthik Mohan, Maryam Fazel, and Babak Hassibi. A simplified approach to recovery con-
ditions for low rank matrices. In 2011 IEEE International Symposium on Information Theory Proceedings,
pages 2318–2322. IEEE, 2011.
[Ope22] OpenAI. OpenAI: Introducing ChatGPT. https://openai.com/blog/chatgpt, 2022.
[Ope23] OpenAI. Gpt-4 technical report. arXiv preprint arXiv:2303.08774, 2023.

[ORST23] Samet Oymak, Ankit Singh Rawat, Mahdi Soltanolkotabi, and Christos Thrampoulidis. On the role of
attention in prompt-tuning. In International Conference on Machine Learning, 2023.
[OS19] Samet Oymak and Mahdi Soltanolkotabi. Overparameterized nonlinear learning: Gradient descent takes
the shortest path? In International Conference on Machine Learning, pages 4951–4960. PMLR, 2019.

[PHD20] Vardan Papyan, XY Han, and David L Donoho. Prevalence of neural collapse during the terminal phase
of deep learning training. Proceedings of the National Academy of Sciences, 117(40):24652–24663, 2020.
[PTDU16] Ankur Parikh, Oscar Täckström, Dipanjan Das, and Jakob Uszkoreit. A decomposable attention model
for natural language inference. In Proceedings of the 2016 Conference on Empirical Methods in Natural
Language Processing, pages 2249–2255, Austin, Texas, November 2016. Association for Computational
Linguistics.
[PXS18] Romain Paulus, Caiming Xiong, and Richard Socher. A deep reinforced model for abstractive summariza-
tion. In International Conference on Learning Representations, 2018.
[QQ19] Qian Qian and Xiaoyuan Qian. The implicit bias of adagrad on separable data. Advances in Neural
Information Processing Systems, 32, 2019.

[RFP10] Benjamin Recht, Maryam Fazel, and Pablo A Parrilo. Guaranteed minimum-rank solutions of linear
matrix equations via nuclear norm minimization. SIAM review, 52(3):471–501, 2010.
[RSR+ 20] Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou,
Wei Li, and Peter J Liu. Exploring the limits of transfer learning with a unified text-to-text transformer.
Journal of Machine Learning Research, 21(1):5485–5551, 2020.
[RXH11] Benjamin Recht, Weiyu Xu, and Babak Hassibi. Null space conditions and thresholds for rank minimiza-
tion. Mathematical programming, 127:175–202, 2011.
[RZH03] Saharon Rosset, Ji Zhu, and Trevor Hastie. Margin maximizing loss functions. Advances in neural
information processing systems, 16, 2003.

[SATA22] Haoyuan Sun, Kwangjun Ahn, Christos Thrampoulidis, and Navid Azizan. Mirror descent maximizes
generalized margin and can be implemented efficiently. Advances in Neural Information Processing
Systems, 35:31089–31101, 2022.
[SEO+ 22] Arda Sahiner, Tolga Ergen, Batu Ozturkler, John Pauly, Morteza Mardani, and Mert Pilanci. Unraveling
attention via convex duality: Analysis and interpretations of vision transformers. In International
Conference on Machine Learning, pages 19050–19088. PMLR, 2022.
[SHN+ 18] Daniel Soudry, Elad Hoffer, Mor Shpigel Nacson, Suriya Gunasekar, and Nathan Srebro. The implicit
bias of gradient descent on separable data. The Journal of Machine Learning Research, 19(1):2822–2878,
2018.
[SPR18] Arun Suggala, Adarsh Prasad, and Pradeep K Ravikumar. Connecting optimization and regularization
paths. Advances in Neural Information Processing Systems, 31, 2018.
[SRJ04] Nathan Srebro, Jason Rennie, and Tommi Jaakkola. Maximum-margin matrix factorization. Advances in
neural information processing systems, 17, 2004.
[SS21] Dominik Stöger and Mahdi Soltanolkotabi. Small random initialization is akin to spectral learning: Opti-
mization and generalization guarantees for overparameterized low-rank matrix reconstruction. Advances
in Neural Information Processing Systems, 34:23831–23843, 2021.
[TBS+ 16] Stephen Tu, Ross Boczar, Max Simchowitz, Mahdi Soltanolkotabi, and Ben Recht. Low-rank solutions
of linear matrix equations via procrustes flow. In International Conference on Machine Learning, pages
964–973. PMLR, 2016.

[TCD+ 21] Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles, and Hervé
Jégou. Training data-efficient image transformers & distillation through attention. In International
Conference on Machine Learning, pages 10347–10357. PMLR, 2021.
[TG07] Joel A Tropp and Anna C Gilbert. Signal recovery from random measurements via orthogonal matching
pursuit. IEEE Transactions on information theory, 53(12):4655–4666, 2007.

[Tib96] Robert Tibshirani. Regression shrinkage and selection via the lasso. Journal of the Royal Statistical
Society Series B: Statistical Methodology, 58(1):267–288, 1996.
[TKVB22] Christos Thrampoulidis, Ganesh Ramachandra Kini, Vala Vakilian, and Tina Behnia. Imbalance trouble:
Revisiting neural-collapse geometry. Advances in Neural Information Processing Systems, 35:27225–
27238, 2022.

[TLI+ 23] Hugo Touvron, Thibaut Lavril, Gautier Izacard, Xavier Martinet, Marie-Anne Lachaux, Timothée Lacroix,
Baptiste Rozière, Naman Goyal, Eric Hambro, Faisal Azhar, et al. Llama: Open and efficient foundation
language models. arXiv preprint arXiv:2302.13971, 2023.
[TLZO23] Davoud Ataee Tarzanagh, Yingcong Li, Xuechen Zhang, and Samet Oymak. Margin maximization in
attention mechanism. arXiv preprint arXiv:2306.13596, 2023.
[TVS23] Nadav Timor, Gal Vardi, and Ohad Shamir. Implicit regularization towards rank minimization in relu
networks. In International Conference on Algorithmic Learning Theory, pages 1429–1459. PMLR, 2023.
[TWCD23] Yuandong Tian, Yiping Wang, Beidi Chen, and Simon Du. Scan and snap: Understanding training
dynamics and token composition in 1-layer transformer. arXiv:2305.16380, 2023.

[VKR19] Tomas Vaskevicius, Varun Kanade, and Patrick Rebeschini. Implicit regularization for optimal sparse
recovery. Advances in Neural Information Processing Systems, 32:2972–2983, 2019.
[VSP+ 17] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz
Kaiser, and Illia Polosukhin. Attention is all you need. Advances in neural information processing systems,
30, 2017.

[WGL+ 20] Blake Woodworth, Suriya Gunasekar, Jason D Lee, Edward Moroshko, Pedro Savarese, Itay Golan, Daniel
Soudry, and Nathan Srebro. Kernel and rich regimes in overparametrized models. In Conference on
Learning Theory, pages 3635–3673. PMLR, 2020.
[WMCL21] Bohan Wang, Qi Meng, Wei Chen, and Tie-Yan Liu. The implicit bias for adaptive optimization algorithms
on homogeneous neural networks. In International Conference on Machine Learning, pages 10849–10858.
PMLR, 2021.
[WMZ+ 21] Bohan Wang, Qi Meng, Huishuai Zhang, Ruoyu Sun, Wei Chen, and Zhi-Ming Ma. Momentum doesn’t
change the implicit bias. arXiv preprint arXiv:2110.03891, 2021.
[WT22] Ke Wang and Christos Thrampoulidis. Binary classification of gaussian mixtures: Abundance of support
vectors, benign overfitting, and regularization. SIAM Journal on Mathematics of Data Science, 4(1):260–
284, 2022.
[WWX+ 22] Haixu Wu, Jialong Wu, Jiehui Xu, Jianmin Wang, and Mingsheng Long. Flowformer: Linearizing
transformers with conservation flows. In International Conference on Machine Learning, pages 24226–
24242, 2022.

[YKM20] Chulhee Yun, Shankar Krishnan, and Hossein Mobahi. A unifying view on implicit bias in training linear
neural networks. arXiv preprint arXiv:2010.02501, 2020.
[ZFB23] Ruiqi Zhang, Spencer Frei, and Peter L Bartlett. Trained transformers learn linear models in-context.
arXiv preprint arXiv:2306.09927, 2023.
[ZWB+ 21] Difan Zou, Jingfeng Wu, Vladimir Braverman, Quanquan Gu, Dean P Foster, and Sham Kakade. The
benefits of implicit regularization from sgd in least squares problems. Advances in Neural Information
Processing Systems, 34:5456–5468, 2021.
[ZY05] Tong Zhang and Bin Yu. Boosting with early stopping: Convergence and consistency. Annals of Statistics,
page 1538, 2005.

Roadmap. The appendix is organized as follows:

• Appendix A provides the proof of Theorem 1.

• Appendix B provides auxiliary lemmas about the training risk.


• Appendix C presents the proofs for the global convergence of gradient descent (Section 4).
• Appendix D presents the proofs for the local convergence of gradient descent (Section 5).

• Appendix E provides a general regularization path analysis. This analysis addresses the inductive bias of the
attention layer for general norm objectives and beyond-linear prediction heads under a sequence-to-sequence
classification model. The seq2seq aspect also goes beyond our results in the main body where we predict using
a single output token (Sections 3 and 5.4).
• Appendix F provides additional experiments and their discussion.

Contents

A Proof of Theorem 1: Separability Under Mild Over-Parameterization

B Auxiliary Lemmas
  B.1 Proof of Lemma 1
  B.2 Proof of Lemma 2
  B.3 Proof of Lemma 4
  B.4 A useful lemma for gradient descent analysis

C Global Convergence of Gradient Descent
  C.1 Divergence of norm of the iterates
    C.1.1 Proof of Theorem 3
  C.2 Global convergence under good initial gradient
    C.2.1 Proof of Theorem 4

D Local Convergence of Gradient Descent
  D.1 Proof of Theorem 5

E Convergence of Regularization Path for Sequence-to-Sequence Setting
  E.1 Local regularization path and proof of Theorem 7
  E.2 Global regularization path
    E.2.1 Proof of Theorem 2

F Supporting Experiments

A Proof of Theorem 1: Separability Under Mild Over-Parameterization


We denote the Kronecker product of two matrices via ⊗. Additionally, given W ∈ Rd×d, let us denote its vectorization by w = vec(W) ∈ Rd². We first note that separation is implied by the linear independence of the constraints. Specifically, we are interested in guaranteeing

    ⟨w, fit⟩ ≥ 1 for all i ∈ [n], t ≠ αi, where fit := (xiαi − xit) ⊗ zi.

Note that the inequality constraints above are feasible as soon as the fit's are linearly independent. Thus, we will instead prove linear independence of the vectors (fit)i∈[n], t≠αi. Also note that, since there are finitely many α choices, if we show
almost sure separation for a fixed but arbitrary α choice, through union bound, we recover the result for all α. Thus, we
prove the result for a fixed α choice.

We will prove this result inductively. Let Mn−1 ∈ R(n−1)(T−1)×d² denote the matrix whose rows are given by the features (fit)i∈[n−1], t≠αi. Suppose the result is correct for n − 1; thus, Mn−1 is full row-rank almost surely (post random Gaussian perturbation). Now, fix Mn−1 and, conditioned on Mn−1 being full row-rank, let us show that Mn is also full row-rank almost surely. To prove this, consider the n'th example (Xn, zn). Let (gt)t∈[T], h ∈ Rd be random vectors with i.i.d. N(0, σ²) entries. Consider the perturbed input X′n ∈ RT×d with tokens x′nt = xnt + gt and z′n = zn + h. Note that for self-attention, we set zn = xn1 and h = g1. From these, create the matrix M̃n ∈ R(T−1)×d² with rows (f′nt)t≠αn, where f′nt = (x′nαn − x′nt) ⊗ z′n. Observe that Mn = [M̃n; Mn−1]. To conclude with the result, we will apply Lemma 7. To apply this lemma, we have two claims.
Claim 1: Let z̄n be the projection of z′n on the orthogonal complement of (zi)i∈[n−1]. Consider the matrix M̄n with rows f̄nt = (x′nαn − x′nt) ⊗ z̄n for t ≠ αn. M̄n is rank T − 1 almost surely whenever d ≥ max(T − 1, n).
To see this claim, first denote the orthogonal complement of the span of the vectors (zi)i∈[n−1] by Zn−1. The span of the vectors (zi)i∈[n−1] is at most (n − 1)-dimensional and, since d ≥ n, dim(Zn−1) ≥ 1. Consequently, z̄n ≠ 0 almost surely because the Gaussian variable zn + h will have a nonzero projection on Zn−1 almost surely. Secondly, let X̄ ∈ R(T−1)×d be the matrix whose rows are equal to x′nαn − x′nt for t ≠ αn. X̄ is full row-rank almost surely; this is because, conditioned on gαn, the matrix X̄ is written as X̄ = X̃ + G where X̃ is deterministic and G is i.i.d. Gaussian. The latter perturbation ensures full row-rank almost surely whenever T − 1 ≤ d. Finally, note that M̄n = X̄ ⊗ z̄n. Since the rank of the Kronecker product is multiplicative, we conclude with the claim.
Claim 2: Let Sn−1 ⊂ Rd² be the null space of Mn−1. There exists a subspace P ⊆ Sn−1 such that the rows of M̄n are the projections of the rows of M̃n onto P, that is, f̄nt = ΠP(f′nt), where ΠP denotes the orthogonal projection onto P.
To show this claim, let us consider the matrix forms of the vectorized features, i.e., let us work with Rd×d rather than Rd². Denote the notation change as Fit = (xiαi − xit)zi⊤ ↔ fit = (xiαi − xit) ⊗ zi. Recall that Zn−1 denotes the orthogonal complement of (zi)i∈[n−1]. Define Q to be the set of matrices in Rd×d whose column space lies in Zn−1, and P to be the vectorization of Q. We first show that P is a subset of Sn−1. To see this, fix any matrix A ∈ Q and a row fit from Mn−1. The matricized fit can be written as Fit = a zi⊤ for zi ∈ Zn−1⊥. Since A ∈ Q, this implies ⟨Fit, A⟩ = a⊤A zi = 0 as A zi = 0. This holds for all Fit; thus, vec(A) ∈ Sn−1.
Next, we need to show that f̄nt is the projection of f′nt onto P. To see this, we will show that f̄nt ∈ P whereas f′nt − f̄nt ∈ P⊥ for all t. Write F′nt = matricized(f′nt) = a z′n⊤. We have that F̄nt = matricized(f̄nt) = a z̄n⊤, where z̄n = ΠZn−1(z′n). This implies F̄nt ∈ Q and f̄nt ∈ P. Similarly, z′n − z̄n ∈ Zn−1⊥, which implies F′nt − F̄nt ∈ Q⊥.
To conclude the proof, observe that, through Claims 1 and 2, Mn = [M̃n; Mn−1] satisfies the requirements of Lemma 7 almost surely, namely, the projection of M̃n onto a subset of the null space of Mn−1 is full rank. Thus, Mn is full rank almost surely. ■
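As a numerical illustration of this argument (not part of the proof), one can check that for generic data with d ≥ max(T − 1, n) the stacked Kronecker features are linearly independent, so the separating constraints are feasible.

```python
import numpy as np

def constraint_rank(X, Z, alpha):
    # rows f_it = (x_{i,alpha_i} - x_{it}) kron z_i for t != alpha_i
    n, T, d = X.shape
    rows = [np.kron(X[i, alpha[i]] - X[i, t], Z[i])
            for i in range(n) for t in range(T) if t != alpha[i]]
    M = np.stack(rows)
    return np.linalg.matrix_rank(M), M.shape[0]

n, T, d = 4, 3, 6                      # d >= max(T - 1, n)
X = np.random.randn(n, T, d)
Z = X[:, 0]                            # self-attention: z_i = x_{i1}
alpha = np.zeros(n, dtype=int)
print(constraint_rank(X, Z, alpha))    # equal values => full row-rank => feasible
```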
Lemma 7 Let A ∈ Rn×p , B ∈ Rm×p . Suppose n + m ≤ p and A is full row-rank. Denote the null space of A by S ⊥A . Let
P be a subspace that is its subset i.e. P ⊆ S ⊥A . Let B′ be the matrix obtained by projecting each of row of B on P and
suppose B′ is full rank. Then, the concatenation C = [ A; B] is full row-rank.
Proof. Let (a_i)_{i=1}^n, (b_i)_{i=1}^m, (b′_i)_{i=1}^m be the rows of A, B, B′, respectively. Suppose the rows of A and B are linearly dependent. Then, for some coefficients (c_i)_{i=1}^n, (c′_i)_{i=1}^m that are not all zero, we have

    ∑_{i=1}^n c_i a_i + ∑_{i=1}^m c′_i b_i = 0.    (13)

We now rewrite this as follows to decouple P and P^⊥:

    ∑_{i=1}^n c_i a_i + ∑_{i=1}^m c′_i b′_i + ∑_{i=1}^m c′_i (b_i − b′_i) = 0.

Projecting the above equality onto P, we find that ∑_{i=1}^m c′_i b′_i = 0, since the a_i lie in the row space of A (which is orthogonal to P ⊆ S_A^⊥) and b_i − b′_i ⊥ P. As (b′_i)_{i=1}^m are linearly independent, we find c′_i = 0 for all i ∈ [m]. This implies ∑_{i=1}^n c_i a_i = 0. Since (a_i)_{i=1}^n are linearly independent, c_i = 0 for all i ∈ [n]. Thus, (13) can only hold if all coefficients are zero, which is a contradiction. ■
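As an informal numerical sanity check of Lemma 7 (not part of the formal argument), the sketch below builds a random full row-rank A, takes a subspace P inside null(A), projects the rows of a random B onto P, and confirms that the stacked matrix is full row-rank; the dimensions, the random seed, and the use of numpy are illustrative assumptions.

```python
# Numerical sketch of Lemma 7 with random data (illustrative sizes only).
import numpy as np

rng = np.random.default_rng(0)
n, m, p = 3, 4, 10                       # requires n + m <= p
A = rng.standard_normal((n, p))          # full row-rank almost surely

# Orthonormal basis of null(A); take an m-dimensional subspace P inside it.
_, _, Vt = np.linalg.svd(A)
null_basis = Vt[n:].T                    # p x (p - n), spans null(A)
P = null_basis[:, :m]                    # orthonormal basis of P (p x m)

B = rng.standard_normal((m, p))
B_proj = B @ P @ P.T                     # rows of B projected onto P

print(np.linalg.matrix_rank(B_proj))                 # m: projected rows are linearly independent
print(np.linalg.matrix_rank(np.vstack([A, B])))      # n + m: [A; B] is full row-rank
```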

B Auxiliary Lemmas
B.1 Proof of Lemma 1
Let W_⋄^mm denote either solution of (Att-SVM) or (Att-SVM⋆). We claim that W_⋄^mm has rank at most n. Suppose the claim is wrong and the row space of W_⋄^mm does not lie within S = span({z_i}_{i=1}^n). Let W = Π_S(W_⋄^mm) denote the matrix obtained by projecting the rows of W_⋄^mm onto S. Observe that W satisfies all SVM constraints since W z_i = W_⋄^mm z_i for all i ∈ [n]. For the Frobenius norm, using W_⋄^mm ≠ W, we obtain a contradiction via ∥W_⋄^mm∥²_F = ∥W∥²_F + ∥W_⋄^mm − W∥²_F > ∥W∥²_F. For the nuclear norm, we can write W = UΣV⊤ with Σ ∈ R^{r×r}, where r is the dimension of S and column_span(V) = S.
To proceed, we split the problem into two scenarios.
Scenario 1: Let U_⊥, V_⊥ be orthogonal complements of U, V (viewing matrices with orthonormal columns as subspaces). Suppose U_⊥⊤ W_⋄^mm V_⊥ ≠ 0. Then, singular value inequalities (which were also used in earlier works on nuclear norm analysis [RXH11, OH10, OMFH11]) guarantee that ∥W_⋄^mm∥_⋆ ≥ ∥U⊤W_⋄^mm V∥_⋆ + ∥U_⊥⊤W_⋄^mm V_⊥∥_⋆ > ∥W∥_⋆.
Scenario 2: Now suppose U_⊥⊤ W_⋄^mm V_⊥ = 0. Since W_⋄^mm V_⊥ ≠ 0, this implies U⊤W_⋄^mm V_⊥ ≠ 0. Let W′ = UU⊤W_⋄^mm, which is a rank-r matrix. Since W′ is a subspace projection of W_⋄^mm, we have ∥W′∥_⋆ ≤ ∥W_⋄^mm∥_⋆. Next, observe that ∥W∥_⋆ = trace(U⊤WV) = trace(U⊤W′V). On the other hand, trace(U⊤W′V) < ∥W′∥_⋆ because equality in von Neumann's trace inequality holds if and only if the two matrices whose inner product is taken, namely (W′, UV⊤), share a joint set of singular vectors [Car21]. However, this is not the case here, as the row space of W_⋄^mm does not lie within S. Thus, we obtain ∥W∥_⋆ < ∥W′∥_⋆ ≤ ∥W_⋄^mm∥_⋆, concluding the proof via contradiction. ■
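A minimal numerical sketch of the projection step in the argument above, assuming random Gaussian data and illustrative dimensions: projecting the rows of a candidate solution onto S = span({z_i}) preserves the constraint values W z_i while weakly decreasing both the Frobenius and nuclear norms.

```python
# Numerical sketch of the row-projection argument in Lemma 1 (illustrative sizes only).
import numpy as np

rng = np.random.default_rng(1)
d, n = 8, 3
W = rng.standard_normal((d, d))
Z = rng.standard_normal((d, n))                  # columns are z_1, ..., z_n

Q, _ = np.linalg.qr(Z)                           # orthonormal basis of S = span({z_i})
P_S = Q @ Q.T                                    # orthogonal projector onto S
W_proj = W @ P_S                                 # rows of W projected onto S

print(np.allclose(W_proj @ Z, W @ Z))            # constraints preserved: W_proj z_i = W z_i
print(np.linalg.norm(W_proj), "<=", np.linalg.norm(W))                  # Frobenius norm shrinks
print(np.linalg.norm(W_proj, "nuc"), "<=", np.linalg.norm(W, "nuc"))    # nuclear norm shrinks
```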

B.2 Proof of Lemma 2


We first show that L(W) > L_⋆ := (1/n) ∑_{i=1}^n ℓ(γ_{i opt_i}). The token at the output of the attention layer is given by a_i = X_i⊤ s_i, where s_i = S(X_i W z_i). Here, a_i can be written as a_i = ∑_{t∈[T]} s_it x_it, where s_it ≥ 0 and ∑_{t∈[T]} s_it = 1. To proceed, using the linearity of h(x) = v⊤x, we find that

    L(W) = (1/n) ∑_{i=1}^n ℓ(Y_i · h(a_i)) = (1/n) ∑_{i=1}^n ℓ(Y_i · ∑_{t∈[T]} s_it h(x_it))
         ≥ (1/n) ∑_{i=1}^n ℓ(Y_i · h(x_{i opt_i})) = (1/n) ∑_{i=1}^n ℓ(γ_{i opt_i}) = L_⋆.    (14)

Here, the inequality follows since γ_it = Y_i · h(x_it) = Y_i · v⊤x_it ≤ γ_{i opt_i} by Definition 1 and the strictly-decreasing nature of the loss ℓ due to Assumption A.
On the other hand, since not all tokens are optimal, there exists a token index (i, t) for which Y_i · h(x_it) < Y_i · h(x_{i opt_i}). Since all softmax entries obey s_it > 0 for finite W, this implies the strict inequality ℓ(Y_i · h(a_i)) > ℓ(Y_i · h(x_{i opt_i})). This leads to the desired conclusion L(W) > L_⋆.
Next, we show that if (Att-SVM) is feasible, i.e., there exists a W separating some optimal indices (opt_i)_{i=1}^n from the other tokens, then lim_{R→∞} L(R · W) = L_⋆. Note that this assumption does not exclude the existence of other optimal indices. As R → ∞, S(X_i(R · W)z_i) saturates the softmax and converges to the indicator function at opt_i for all inputs i ∈ [n]. Thus, s_it → 0 for t ≠ opt_i and s_it → 1 for t = opt_i. Using M_1-Lipschitzness of ℓ, we can write

    |ℓ(Y_i · h(x_{i opt_i})) − ℓ(Y_i · h(a_i))| ≤ M_1 |h(a_i) − h(x_{i opt_i})|.

Since h is linear, it is ∥v∥-Lipschitz, implying

    |ℓ(Y_i · h(x_{i opt_i})) − ℓ(Y_i · h(a_i))| ≤ M_1 ∥v∥ · ∥a_i − x_{i opt_i}∥.

Since a_i → x_{i opt_i} as R → ∞, (14) gives lim_{R→∞} L(R · W) = L_⋆. ■
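An informal numerical illustration of this saturation effect, under an assumed synthetic dataset with a linear head h(x) = v⊤x, label Y = 1, and a hand-constructed separating direction; these choices are illustrative, not prescribed by the lemma.

```python
# Softmax saturation along a separating direction: s_opt -> 1 and L(R*W) -> L_star as R grows.
import numpy as np

rng = np.random.default_rng(2)
T, d = 5, 8
X = rng.standard_normal((T, d))
z = rng.standard_normal(d)
v = rng.standard_normal(d)
opt = int(np.argmax(X @ v))                       # optimal token under the score gamma_t = v^T x_t

# A separating direction: X W z = e_opt, hence (x_opt - x_t)^T W z = 1 for every t != opt.
W = np.linalg.pinv(X) @ np.eye(T)[[opt]].T @ z[None, :] / (z @ z)

softmax = lambda u: np.exp(u - u.max()) / np.exp(u - u.max()).sum()
loss = lambda m: np.log(1 + np.exp(-m))           # logistic loss, a valid choice under Assumption A
L_star = loss(X[opt] @ v)

for R in [1.0, 10.0, 50.0]:
    s = softmax(X @ (R * W) @ z)
    a = X.T @ s                                   # attention output a = X^T S(X (R W) z)
    print(R, s[opt], loss(a @ v) - L_star)        # softmax saturates and the excess loss vanishes
```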

B.3 Proof of Lemma 4


Let
    γ_i = Y_i · X_i v,   h_i = X_i W z_i.
From Assumption A, ℓ: R → R is differentiable. Hence, the gradient evaluated at W is given by

    ∇L(W) = (1/n) ∑_{i=1}^n ℓ′(γ_i⊤ S(h_i)) · X_i⊤ S′(h_i) γ_i z_i⊤,    (15)

where
    S′(h) = diag(S(h)) − S(h)S(h)⊤ ∈ R^{T×T}.    (16)

Note that
    ∥S′(h)∥ ≤ ∥S′(h)∥_F ≤ 1.    (17)

Hence, for any W, Ẇ ∈ R^{d×d} and i ∈ [n], we have

    ∥S(h_i) − S(ḣ_i)∥ ≤ ∥h_i − ḣ_i∥ ≤ ∥X_i∥ ∥z_i∥ ∥W − Ẇ∥_F,    (18a)

where ḣ_i = X_i Ẇ z_i. Similarly,

    ∥S′(h_i) − S′(ḣ_i)∥_F ≤ ∥S(h_i) − S(ḣ_i)∥ + ∥S(h_i)S(h_i)⊤ − S(ḣ_i)S(ḣ_i)⊤∥_F ≤ 3 ∥X_i∥ ∥z_i∥ ∥W − Ẇ∥_F.    (18b)

Next, for any W, Ẇ ∈ R^{d×d}, we get

    ∥∇L(W) − ∇L(Ẇ)∥_F ≤ (1/n) ∑_{i=1}^n ∥ℓ′(γ_i⊤S(h_i)) · z_i γ_i⊤ S′(h_i) X_i − ℓ′(γ_i⊤S(ḣ_i)) · z_i γ_i⊤ S′(ḣ_i) X_i∥_F
        ≤ (1/n) ∑_{i=1}^n ∥z_i γ_i⊤ S′(ḣ_i) X_i∥_F |ℓ′(γ_i⊤S(h_i)) − ℓ′(γ_i⊤S(ḣ_i))|
          + (1/n) ∑_{i=1}^n |ℓ′(γ_i⊤S(h_i))| ∥z_i γ_i⊤ S′(h_i) X_i − z_i γ_i⊤ S′(ḣ_i) X_i∥_F
        ≤ (1/n) ∑_{i=1}^n M_0 ∥γ_i∥² ∥z_i∥ ∥X_i∥ ∥S(h_i) − S(ḣ_i)∥
          + (1/n) ∑_{i=1}^n M_1 ∥γ_i∥ ∥z_i∥ ∥X_i∥ ∥S′(h_i) − S′(ḣ_i)∥_F,    (19)

where the second inequality follows from the fact that |ab − cd| ≤ |d||a − c| + |a||b − d|, and the third inequality uses Assumption A and (17).
Substituting (18a) and (18b) into (19), we get

    ∥∇L(W) − ∇L(Ẇ)∥_F ≤ (1/n) ∑_{i=1}^n ( M_0 ∥γ_i∥² ∥z_i∥² ∥X_i∥² + 3M_1 ∥γ_i∥ ∥z_i∥² ∥X_i∥² ) ∥W − Ẇ∥_F
        ≤ (1/n) ∑_{i=1}^n ( M_0 ∥v∥² ∥z_i∥² ∥X_i∥⁴ + 3M_1 ∥v∥ ∥z_i∥² ∥X_i∥³ ) ∥W − Ẇ∥_F
        ≤ L_W ∥W − Ẇ∥_F,

where L_W is defined in (3).
Let g_i = X_i K Q⊤ z_i. We have

    ∇_K L(K, Q) = (1/n) ∑_{i=1}^n ℓ′(γ_i⊤ S(g_i)) · z_i γ_i⊤ S′(g_i) X_i Q,    (20a)
    ∇_Q L(K, Q) = (1/n) ∑_{i=1}^n ℓ′(γ_i⊤ S(g_i)) · X_i⊤ S′(g_i) γ_i z_i⊤ K.    (20b)

By a similar argument as in (19), for any Q, Q̇ ∈ R^{d×m}, we have

    ∥∇_Q L(K, Q) − ∇_Q L(K, Q̇)∥_F ≤ (∥K∥/n) ∑_{i=1}^n ∥ℓ′(γ_i⊤S(h_i)) · z_i γ_i⊤ S′(h_i) X_i − ℓ′(γ_i⊤S(ḣ_i)) · z_i γ_i⊤ S′(ḣ_i) X_i∥_F
        ≤ L_W ∥K∥ ∥Q − Q̇∥_F.    (21)

Similarly, for any K, K̇ ∈ R^{d×m}, we get

    ∥∇_K L(K, Q) − ∇_K L(K̇, Q)∥_F ≤ L_W ∥Q∥ ∥K − K̇∥_F. ■
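The bounds (16)-(18a) can be probed with a short numerical sketch (random logits, illustrative dimension); it evaluates the softmax Jacobian S′(h) = diag(s) − ss⊤ and checks that its Frobenius norm stays below 1 and that the softmax map is 1-Lipschitz.

```python
# Numerical check of the softmax Jacobian bounds used in the Lipschitz estimates above.
import numpy as np

rng = np.random.default_rng(3)
T = 7
softmax = lambda u: np.exp(u - u.max()) / np.exp(u - u.max()).sum()

for _ in range(5):
    h, h_dot = rng.standard_normal(T), rng.standard_normal(T)
    s, s_dot = softmax(h), softmax(h_dot)
    S_prime = np.diag(s) - np.outer(s, s)                    # S'(h) as in (16)
    print(np.linalg.norm(S_prime), "<= 1")                   # (17)
    print(np.linalg.norm(s - s_dot), "<=", np.linalg.norm(h - h_dot))   # 1-Lipschitzness, cf. (18a)
```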

B.4 A useful lemma for gradient descent analysis


Lemma 8 For any X ∈ R^{T×d}, W, V ∈ R^{d×d}, and z, v ∈ R^d, let a = XVz, s = S(XWz), and γ = Xv. Set

    Γ = sup_{t,τ∈[T]} |γ_t − γ_τ|   and   A = sup_{t∈[T]} ∥a_t∥.

We have that

    | a⊤diag(s)γ − a⊤ss⊤γ − ∑_{t≥2} (a_1 − a_t) s_t (γ_1 − γ_t) | ≤ 2ΓA(1 − s_1)².

Proof. The proof is similar to [TLZO23, Lemma 4], but for the sake of completeness, we provide it here. Set γ̄ = ∑_{t=1}^T γ_t s_t. We have

    γ_1 − γ̄ = ∑_{t≥2} (γ_1 − γ_t) s_t,   and   |γ_1 − γ̄| ≤ Γ(1 − s_1).

Then,

    a⊤diag(s)γ − a⊤ss⊤γ = ∑_{t=1}^T a_t γ_t s_t − ( ∑_{t=1}^T a_t s_t )( ∑_{t=1}^T γ_t s_t )
        = a_1 s_1 (γ_1 − γ̄) − ∑_{t≥2} a_t s_t (γ̄ − γ_t).    (22)

Since

    | ∑_{t≥2} a_t s_t (γ̄ − γ_t) − ∑_{t≥2} a_t s_t (γ_1 − γ_t) | ≤ AΓ(1 − s_1)²,

we obtain⁵

    a⊤diag(s)γ − a⊤ss⊤γ = a_1 s_1 (γ_1 − γ̄) − ∑_{t≥2} a_t s_t (γ_1 − γ_t) ± AΓ(1 − s_1)²
        = a_1 s_1 ∑_{t≥2} (γ_1 − γ_t) s_t − ∑_{t≥2} a_t s_t (γ_1 − γ_t) ± AΓ(1 − s_1)²
        = ∑_{t≥2} (a_1 s_1 − a_t) s_t (γ_1 − γ_t) ± AΓ(1 − s_1)²
        = ∑_{t≥2} (a_1 − a_t) s_t (γ_1 − γ_t) ± 2AΓ(1 − s_1)².

Here, ± on the right-hand side uses the fact that

    | ∑_{t≥2} (a_1 s_1 − a_1) s_t (γ_1 − γ_t) | ≤ (1 − s_1) AΓ ∑_{t≥2} s_t = (1 − s_1)² AΓ. ■

⁵ For simplicity, we use ± on the right-hand side to denote the upper and lower bounds.
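As an informal sanity check of Lemma 8 on random instances (illustrative sizes; token 1 plays the role of the selected token):

```python
# Numerical check of the approximation in Lemma 8 on random instances.
import numpy as np

rng = np.random.default_rng(4)
T, d = 6, 5
softmax = lambda u: np.exp(u - u.max()) / np.exp(u - u.max()).sum()

for _ in range(5):
    X = rng.standard_normal((T, d))
    W, V = rng.standard_normal((d, d)), rng.standard_normal((d, d))
    z, v = rng.standard_normal(d), rng.standard_normal(d)
    a, s, g = X @ V @ z, softmax(X @ W @ z), X @ v           # a = XVz, s = S(XWz), gamma = Xv
    Gamma = np.max(np.abs(g[:, None] - g[None, :]))
    A = np.max(np.abs(a))                                    # entries a_t are scalars here
    lhs = a @ (np.diag(s) - np.outer(s, s)) @ g
    main = np.sum((a[0] - a[1:]) * s[1:] * (g[0] - g[1:]))   # sum over t >= 2 with token 1 as reference
    print(abs(lhs - main), "<=", 2 * Gamma * A * (1 - s[0]) ** 2)
```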

C Global Convergence of Gradient Descent


C.1 Divergence of norm of the iterates
The next lemma establishes the descent property of gradient descent for L(W) under Assumption A.

Lemma 9 (Descent Lemma) Under Assumption A, if η ≤ 1/L_W, then for any initialization W(0), Algorithm W-GD satisfies

    L(W(k + 1)) − L(W(k)) ≤ −(η/2) ∥∇L(W(k))∥²_F,    (23)

for all k ≥ 0. Additionally, it holds that ∑_{k=0}^∞ ∥∇L(W(k))∥²_F < ∞ and lim_{k→∞} ∥∇L(W(k))∥²_F = 0.
Proof. The proof is similar to [TLZO23, Lemma 6].
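For concreteness, here is an informal sketch of the W-GD iteration analyzed in this section on a synthetic toy problem; the linear head, logistic loss, data sizes, and step size are all illustrative assumptions (in particular, η below is a small heuristic stand-in for 1/L_W, which is not computed here).

```python
# Sketch of gradient descent on (W-ERM) with a linear head and logistic loss.
import numpy as np

rng = np.random.default_rng(5)
n, T, d = 4, 6, 8
Xs = rng.standard_normal((n, T, d))
zs = Xs[:, 0, :]                                    # self-attention convention z_i = x_i1
v = rng.standard_normal(d)
Y = rng.choice([-1.0, 1.0], size=n)

softmax = lambda u: np.exp(u - u.max()) / np.exp(u - u.max()).sum()

def loss_and_grad(W):
    # Implements the objective and its gradient (15) for logistic loss.
    L, G = 0.0, np.zeros((d, d))
    for X, z, y in zip(Xs, zs, Y):
        s = softmax(X @ W @ z)
        gamma = y * (X @ v)                         # token scores gamma_i = Y_i * X_i v
        m = gamma @ s
        lp = -1.0 / (1.0 + np.exp(m))               # l'(m) for l(m) = log(1 + exp(-m))
        Sp = np.diag(s) - np.outer(s, s)            # softmax Jacobian S'(h), cf. (16)
        G += lp * np.outer(X.T @ Sp @ gamma, z)
        L += np.log(1.0 + np.exp(-m))
    return L / n, G / n

W, eta = np.zeros((d, d)), 1e-3
for k in range(1, 501):
    L, G = loss_and_grad(W)
    W -= eta * G
    if k % 100 == 0:
        # Lemma 9 predicts monotone decrease for eta <= 1/L_W; Theorem 3 predicts ||W(k)|| -> infinity.
        print(k, L, np.linalg.norm(W))
```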


The lemma below reveals that the correlation between the training loss’s gradient at any arbitrary matrix W and the
attention SVM solution W mm is negative. Consequently, for any finite W, ⟨∇L(W), W mm ⟩ cannot be equal to zero.

Lemma 10 Let W mm be the SVM solution of (Att-SVM). Suppose Assumptions A and B hold. Then, for all W ∈ Rd×d ,
the training loss (W-ERM) obeys ⟨∇L(W), W mm ⟩ < 0.

Proof. Let

    h̄_i = X_i W^mm z_i,   γ_i = Y_i · X_i v,   and   h_i = X_i W z_i.    (24)

Let us recall the gradient evaluated at W, which is given by

    ∇L(W) = (1/n) ∑_{i=1}^n ℓ′(γ_i⊤ S(h_i)) · X_i⊤ S′(h_i) γ_i z_i⊤,    (25)

which implies that

    ⟨∇L(W), W^mm⟩ = (1/n) ∑_{i=1}^n ℓ′(γ_i⊤ S(h_i)) · ⟨X_i⊤ S′(h_i) γ_i z_i⊤, W^mm⟩
        = (1/n) ∑_{i=1}^n ℓ′_i · trace((W^mm)⊤ X_i⊤ S′(h_i) γ_i z_i⊤)
        = (1/n) ∑_{i=1}^n ℓ′_i · h̄_i⊤ S′(h_i) γ_i
        = (1/n) ∑_{i=1}^n ℓ′_i · ( h̄_i⊤ diag(s_i) γ_i − h̄_i⊤ s_i s_i⊤ γ_i ).    (26)

Here, ℓ′_i := ℓ′(γ_i⊤ S(h_i)), s_i = S(h_i), and the third equality uses trace(ba⊤) = a⊤b.

In order to move forward, we will establish the following result, with a focus on the equal-score condition (Assumption B.2). Suppose the non-optimal scores are equal to a constant γ = γ_t for t ≥ 2, and let γ_1 and h̄_1 denote the entries of γ and h̄ at the optimal token (indexed by 1). For any vector s that satisfies ∑_{t∈[T]} s_t = 1 and s_t > 0, we aim to prove that h̄⊤diag(s)γ − h̄⊤ss⊤γ > 0. To demonstrate this, we write:

    h̄⊤diag(s)γ − h̄⊤ss⊤γ = ∑_{t=1}^T h̄_t γ_t s_t − ( ∑_{t=1}^T h̄_t s_t )( ∑_{t=1}^T γ_t s_t )
        = ( h̄_1 γ_1 s_1 + γ ∑_{t≥2} h̄_t s_t ) − ( γ_1 s_1 + γ(1 − s_1) )( h̄_1 s_1 + ∑_{t≥2} h̄_t s_t )
        = h̄_1 (γ_1 − γ) s_1 (1 − s_1) − (γ_1 − γ) s_1 ∑_{t≥2} h̄_t s_t    (27)
        = (γ_1 − γ)(1 − s_1) s_1 ( h̄_1 − (∑_{t≥2} h̄_t s_t)/(∑_{t≥2} s_t) )
        ≥ (γ_1 − γ)(1 − s_1) s_1 ( h̄_1 − max_{t≥2} h̄_t ).

To proceed, define

    γ_i^gap = γ_{i opt_i} − max_{t≠opt_i} γ_it   and   h̄_i^gap = h̄_{i opt_i} − max_{t≠opt_i} h̄_it.

With these, we obtain

    h̄_i⊤diag(s_i)γ_i − h̄_i⊤s_i s_i⊤γ_i ≥ γ_i^gap h̄_i^gap (1 − s_{i opt_i}) s_{i opt_i}.    (28)

Note that

    h̄_i^gap = min_{t≠opt_i} (x_{i opt_i} − x_it)⊤ W^mm z_i ≥ 1,
    γ_i^gap = min_{t≠opt_i} (γ_{i opt_i} − γ_it) > 0,
    s_{i opt_i} (1 − s_{i opt_i}) > 0.

Hence,

    min_{i∈[n]} { ( min_{t≠opt_i} (x_{i opt_i} − x_it)⊤ W^mm z_i ) · ( min_{t≠opt_i} (γ_{i opt_i} − γ_it) ) · s_{i opt_i}(1 − s_{i opt_i}) } > 0.    (29)

It follows from (28) and (29) that

    min_{i∈[n]} { h̄_i⊤diag(s_i)γ_i − h̄_i⊤s_i s_i⊤γ_i } > 0.    (30)

Further, by Assumption A, ℓ′_i < 0; since ℓ′ is continuous and the relevant domain is bounded, the maximum is attained and negative, and thus

    max_x ℓ′(x) < 0.    (31)

Hence, using (30) and (31) in (26), we obtain

    ⟨∇L(W), W^mm⟩ < 0.    (32)

In the scenario that Assumption B.1 holds (all tokens are support), h̄_it = x_it⊤ W^mm z_i is constant for all t ≥ 2. Hence, following similar steps as in (27) completes the proof. ■

C.1.1 Proof of Theorem 3


It follows from Lemma 9 that, under Assumption A, with η ≤ 1/L_W and any initialization W(0), the gradient descent sequence W(k + 1) = W(k) − η∇L(W(k)) satisfies lim_{k→∞} ∥∇L(W(k))∥²_F = 0.
Further, it follows from Lemma 10 that ⟨∇L(W), W^mm⟩ < 0 for all finite W ∈ R^{d×d}; in particular, ∇L(W) ≠ 0 for any finite W, so L has no finite critical points. Since Lemma 9 forces ∇L(W(k)) → 0, any bounded subsequence of the iterates would have a limit point that is a finite critical point, a contradiction. This implies ∥W(k)∥ → ∞. ■

C.2 Global convergence under good initial gradient
To ensure global convergence, we identify an assumption that prevents GD from getting trapped at suboptimal tokens
that offer no scoring advantage compared to other choices. To establish a foundation for providing the convergence of
GD to the globally optimal solution W mm , we present the following definitions. For parameters µ ∈ (0, 1) and R > 0,
consider the following subset of the sphere and its associated cone:

    S̄_µ(W^mm) := { W ∈ R^{d×d} : ⟨(x_{i opt_i} − x_it)z_i⊤, W/∥W∥_F⟩ ≥ µ/∥W^mm∥_F for all t ≠ opt_i, i ∈ [n] },    (33a)
    C̄_{µ,R}(W^mm) := { W ∈ S̄_µ(W^mm) : ∥W∥_F ≥ R }.    (33b)

Note that the C̄µ,R (W mm ) definition is equivalent to the C̄µ,R definition in (4) with a change of variable µ ← ∥W mm ∥F · µ.
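A minimal sketch of how membership in this cone can be monitored numerically, e.g., along gradient-descent iterates; the function signature and array layout below are illustrative assumptions.

```python
# Membership test for the cone defined in (33), for generic numpy arrays.
import numpy as np

def in_cone(W, W_mm, Xs, zs, opts, mu, R):
    """Return True iff ||W||_F >= R and <(x_opt - x_t) z^T, W/||W||_F> >= mu/||W^mm||_F for all t != opt_i."""
    if np.linalg.norm(W) < R:
        return False
    thresh = mu / np.linalg.norm(W_mm)
    W_unit = W / np.linalg.norm(W)
    for X, z, opt in zip(Xs, zs, opts):
        for t in range(X.shape[0]):
            if t == opt:
                continue
            corr = np.outer(X[opt] - X[t], z).ravel() @ W_unit.ravel()
            if corr < thresh:
                return False
    return True
```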

Lemma 11 Suppose Assumption A holds and let opt = (opt_i)_{i=1}^n be the unique globally-optimal indices, with W^mm denoting the (Att-SVM) solution. Define the margin Θ = 1/∥W^mm∥_F and let s_i = S(X_i W z_i). For any µ > 0, there exists a sufficiently large R̄_µ = O(1/µ) (see (41)) such that:
L1. There is no stationary point within C̄_{µ,R̄_µ}(W^mm), where C̄_{µ,R̄_µ}(W^mm) is defined in (33).
L2. For all V ∈ S̄_µ(W^mm) with ∥V∥_F = ∥W^mm∥_F and all W ∈ C̄_{µ,R̄_µ}(W^mm), there exist dataset-dependent constants C, c > 0 such that

    C · (1/n) ∑_{i=1}^n (1 − s_{i opt_i}) ≥ −⟨∇L(W), V⟩ ≥ c · µ · (1/n) ∑_{i=1}^n (1 − s_{i opt_i}) > 0,    (34a)
    −⟨ V/∥V∥_F, ∇L(W)/∥∇L(W)∥_F ⟩ ≥ (c µ / C) · (Θ / Ā) > 0,    (34b)
    ∥∇L(W)∥_F ≤ Ā C · (1/n) ∑_{i=1}^n (1 − s_{i opt_i}).    (34c)

Here, siopti = (S(XiW zi ))opti , Ā = maxi∈[n],t,τ∈[T ] ∥xit − xiτ ∥ ∥zi ∥, and Θ = 1/∥W mm ∥F .

Proof. For simplicity, let R = R̄_µ, take W ∈ C̄_{µ,R}(W^mm), and define

    A = max_{i∈[n], t,τ∈[T]} (∥x_it∥ ∨ ∥x_it − x_iτ∥) · ∥z_i∥ / Θ.    (35)

The following inequalities hold for all V ∈ S̄_µ with ∥V∥_F = ∥W^mm∥_F and all i ∈ [n], t ≠ opt_i:

    A ≥ (x_{i opt_i} − x_it)⊤ V z_i ≥ µ.    (36)

To proceed, we write the gradient correlation following (15) and (26):

    ⟨∇L(W), V⟩ = (1/n) ∑_{i=1}^n ℓ′_i · ⟨h_i, S′(h̃_i) γ_i⟩,    (37)

where we denoted ℓ′_i = ℓ′(Y_i · v⊤ X_i⊤ S(h̃_i)), h_i = X_i V z_i, h̃_i = X_i W z_i, and s_i = S(h̃_i).
It follows from (35) that A ≥ max_{i∈[n], t∈[T]} ∥h_it∥. Using (36), we can bound the softmax probabilities s_i = S(h̃_i) as follows, for all i ∈ [n]:

    S_i := ∑_{τ≠opt_i} s_iτ ≤ T e^{−RµΘ} s_{i opt_i} ≤ T e^{−RµΘ}.    (38)

Recall the scores γ_it = Y_i · v⊤ x_it. Define the score gaps:

    γ_i^gap = γ_{i opt_i} − max_{t≠opt_i} γ_it,   γ̄_i^gap = γ_{i opt_i} − min_{t≠opt_i} γ_it,   and   Γ = sup_{i∈[n], t,τ∈[T]} |γ_it − γ_iτ|.

Let us focus on a fixed datapoint i ∈ [n], assume (without loss of generality) opt_i = 1, and drop the subscript i. Directly applying Lemma 8, we obtain

    | h⊤diag(s)γ − h⊤ss⊤γ − ∑_{t≥2} (h_1 − h_t) s_t (γ_1 − γ_t) | ≤ 2ΓA(1 − s_1)².

To proceed, let us upper/lower bound the gradient correlation. Since A ≥ h_1 − h_t ≥ µ > 0 from (36), setting S := ∑_{t≠opt} s_t = 1 − s_1, we find

    A · S · γ̄^gap ≥ ∑_{t≠opt} (h_1 − h_t) s_t (γ_1 − γ_t) ≥ µ · S · γ^gap.    (39)

Next, we show that S = 1 − s_1 dominates (1 − s_1)² = S² for large R. Specifically, we wish for

    µ S γ^gap / 2 ≥ 2ΓA(1 − s_1)²  ⟺  S ≥ (4ΓA/(µγ^gap)) S²  ⟺  S ≤ µγ^gap/(4ΓA).    (40)

Using (38), what we wish is ensured for all i ∈ [n] by guaranteeing T e^{−RµΘ} ≤ µγ^gap_min/(4ΓA), that is, by choosing

    R ≥ (1/(µΘ)) log( 4TΓA/(µ γ^gap_min) ),    (41)

where γ^gap_min = min_{i∈[n]} γ_i^gap is the global scalar corresponding to the worst-case score gap over all inputs.
With the above choice of R, we have guaranteed

    2A(1 − s_1) · γ̄^gap ≥ 2A · S · γ̄^gap ≥ ∑_{t≠opt} (h_1 − h_t) s_t (γ_1 − γ_t) ≥ µ · S · γ^gap / 2 ≥ µ(1 − s_1)γ^gap / 2,

via (40) and (39).
Since this holds over all inputs, going back to the gradient correlation (37), averaging the above over all inputs i ∈ [n], and plugging back the indices i, we obtain the advertised bound

    (2A/n) ∑_{i∈[n]} −ℓ′_i · S_i · γ̄_i^gap ≥ −⟨∇L(W), V⟩ ≥ (µ/(2n)) ∑_{i∈[n]} −ℓ′_i · S_i · γ_i^gap.    (42)

Let −ℓ′_min / −ℓ′_max be the min/max values the negative loss derivative admits over the ball [−A, A], and note that max_{i∈[n]} γ̄_i^gap > 0 and min_{i∈[n]} γ_i^gap > 0 are dataset-dependent constants. Then, we declare the constants C = −2A ℓ′_max · max_{i∈[n]} γ̄_i^gap > 0 and c = −(1/2) ℓ′_min · min_{i∈[n]} γ_i^gap > 0 to obtain the bound (34a).
The proofs of (34b) and (34c) follow similarly to the proof of Lemma 13. ■

The following lemma shows that as π approaches zero, the negative gradient of the loss function at W ∈ C̄µ,R (W mm )
becomes more correlated with the max-margin solution (W mm ) than with W itself.
Lemma 12 Suppose Assumption A holds and let opt = (opt_i)_{i=1}^n be the unique optimal tokens with W^mm denoting the SVM solution. Fix any µ > 0 (per Lemma 11). For any choice of π > 0, there exists R := R_π ≥ R̄_µ such that, for any W ∈ C̄_{µ,R}(W^mm), we have

    ⟨∇L(W), W/∥W∥_F⟩ ≥ (1 + π) ⟨∇L(W), W^mm/∥W^mm∥_F⟩.

Here, C̄_{µ,R}(W^mm) is the cone defined in (33).
Proof. Let W̄ = ∥W^mm∥_F W/∥W∥_F, h_i = X_i W̄ z_i, h̄_i = X_i W^mm z_i, and s_i = S(X_i W z_i). To establish the result, we will prove that, for sufficiently large R and for any W ∈ C̄_{µ,R}(W^mm):

    ⟨−∇L(W), W/∥W∥_F⟩ = −(1/n) ∑_{i=1}^n ℓ′_i · ⟨h_i, S′(X_i W z_i) γ_i⟩
        ≤ −((1 + π)/n) ∑_{i=1}^n ℓ′_i · ⟨h̄_i, S′(X_i W z_i) γ_i⟩ = (1 + π) ⟨−∇L(W), W^mm/∥W^mm∥_F⟩.    (43)

Directly applying Lemma 8, for all V ∈ S̄µ with ∥V∥F = ∥W mm ∥F and h̃i = Xi V zi , we have found
X
h̃⊤i diag(si )γi − h̃⊤i si s⊤i γi − ( h̃i1 − h̃it )sit (γi1 − γit ) ≤ 2ΓA(1 − si1 )2 . (44)
t,opti

Recalling h̄i1 − h̄it ≥ 1, we note that t,opti sit (γi1 − γit ) ≤ t,opti ( h̄i1 − h̄it )sit (γi1 − γit ). Now plugging in h, h̄ in the
P P
bound above and assuming π ≤ 1 (w.l.o.g.), (43) is implied by the following stronger inequality
n
 
1 X ′  X 
− ℓi · 6ΓA(1 − si1 ) +
2
(hi1 − hit )sit (γi1 − γit )
n i=1 t,opt i
n
1+πX X
≤− ℓi′ · ( h̄i1 − h̄it )sit (γi1 − γit )
n i=1 t,opti
n
1+π X X
≤− ℓi′ · sit (γi1 − γit ).
n i=1 t,opti

First, we claim that 0.5π t∈opti sit (γi1 − γit ) ≥ 6ΓA(1 − si1 )2 for all i ∈ [n]. The proof of this claim directly follows the
P
gap
argument in Lemma 11, (namely following (38), (40), (41)) we have that 1 − si1 ≤ T e−RµΘ and γi1 − γit ≥ γmin for all
i ∈ [n]. This leads to the choice (for D0 ≥ 12)
 D0 · T ΓA 
 
1
R ≥ Rπ = log  . (45)
µΘ gap 
πγmin

We shall choose D0 sufficiently large such that Rπ ≥ R̄µ , where R̄µ is defined in Lemma 11.
Following this control over the perturbation term 6ΓA(1 − si1 )2 , to conclude with the result, what remains is proving
the comparison
n n
1X ′ X 1 + 0.5π X ′ X
− ℓi · (hi1 − hit )sit (γi1 − γit ) ≤ − ℓi · sit (γi1 − γit ). (46)
n i=1 t,opt
n i=1 t,opt
i i

π
Scenario 1: ∥W̄ − W mm ∥F ≤ ϵ = 4AΘ for some ϵ > 0. In this scenario, for any t , opti and i ∈ [n], we have
π
|hit − h̄it | = |x⊤it (W̄ − W mm )zi | ≤ AΘϵ = .
4
Consequently, we obtain
hi1 − hit ≤ h̄i1 − h̄it + 2AΘϵ = 1 + 0.5π.
Similarly, hi1 − hit ≥ 1 − 0.5π ≥ 0.5. Since all terms hi1 − hit , sit , γi1 − γit in (46) are nonnegative, we obtain (46).
π
Scenario 2: ∥W̄ − W mm ∥F ≥ ϵ = 4AΘ . Since W̄ is not max-margin solution, in this scenario, for some i ∈ [n],
ν = ν(ϵ) > 0, and τ , opti , we have that
hi1 − hiτ ≤ 1 − 2ν.

Here τ = arg maxτ,opti xiτW̄ zi denotes the nearest point to hi1 (along the W̄ direction). Recall that s = S(R̄h), where
R̄ = RΘ = ∥W∥F /∥W mm ∥F . To proceed, let hi := mint,opti hi1 − hit ,
n o n o
I := i ∈ [n] : hi ≤ 1 − 2ν , [n] − I := i ∈ [n] : 1 − 2ν < hi .
For all i ∈ [n] − I,
X X
(hi1 − hit )sit (γi1 − γit ) − (1 + 0.5π) sit (γi1 − γit )
t,opti t,opti
X
≤ (2A − (1 + 0.5π)) Γ sit
t,opti , hi1 −hit ≥1+ π2 (47)
π
≤ (2A − (1 + 0.5π)) ΓT e−R̄(1+ 2 )
π
≤ 2AΓT e−R̄(1+ 2 ) .

For all i ∈ I, split the tokens into two groups: Let Ni be the group of tokens obeying hi1 − hit ≤ 1 − ν and
N̄i := [T ] − {opti } − Ni be the rest of the neighbors. Observe that
eνR̄
P
t∈N̄ sit
P i ≤T = T e−R̄ν .
t,opti sit e2νR̄
gap gap
Using |hi1 − hit | ≤ 2A and γmin = mini∈[n] γi = mini∈[n] (γi1 − maxt,opti γit ), observe that
X 2ΓAT e−R̄ν X
(hi1 − hit )sit (γi1 − γit ) ≤ gap sit (γi1 − γit ).
t∈N̄i
γmin t,opti

Thus,
X X X
(hi1 − hit )sit (γi1 − γit ) = (hi1 − hit )sit (γi1 − γit ) + (hi1 − hit )sit (γi1 − γit )
t,opti t∈Ni t∈N̄i
X 2ΓAT e−R̄ν X
≤ (1 − ν)sit (γi1 − γit ) + gap sit (γi1 − γit )
t∈Ni
γmin t,opti
 
2ΓAT e−R̄ν  X
≤ 1 − ν + sit (γi1 − γit )

gap
γmin
 
t,opti
 
2ΓAT e−R̄ν  X
≤ 1 + sit (γi1 − γit ).

gap
γmin

t,opti

Hence, choosing
 
1  8ΓAT 
R≥ log  gap  (48)
νΘ γmin π
results in that
X  π X
(hi1 − hit )sit (γi1 − γit ) − 1 + sit (γi1 − γit )
t,opti
2 t,opt
i

 2ΓAT e−R̄ν π  X


 
≤  gap −  sit (γi1 − γit )
γmin 2 t,opt
i (49)
π X
≤− sit (γi1 − γit )
4 t,opt
i
π gap
≤ − γmin e−R̄(1−2ν) .
4T
−R̄(1−2ν)
Here, the last inequality follows from the fact that t,opti sit ≥ maxt,opti sit ≥ PT e e−R̄(hi1 −hit ) ≥ e−R̄(1−2ν) /T .
P
t=1
From Assumption A, we have cmin ≤ −ℓ′ ≤ cmax for some positive constants cmin and cmax . It follows from (47) and
(49) that
n
 
1 X ′  X X 
− ℓi ·  (hi1 − hit )sit (γi1 − γit ) − (1 + 0.5π)sit (γi1 − γit )
n i t,opt t,opt
i i
gap
π cmin πγmin −R̄(1−2ν)
≤ cmax 2AΓT Γe−R̄(1+ 2 ) − · e
nT 4
≤ 0.
Combining with (48), this is guaranteed by choosing

    R ≥ max{ (1/(νΘ)) log( 8ΓAT/(γ^gap_min π) ),  (1/((2ν + π/2)Θ)) log( 8nΓAT² c_max/(c_min γ^gap_min π) ) },

where ν = ν(π/(4AΘ)) depends only on π and global problem variables.
Combining this with the prior choice of R in (45) (by taking the maximum), we conclude with the statement.

C.2.1 Proof of Theorem 4
The following theorem is a restatement of Theorem 4. Two minor differences are: (1) We state with the change of
variable µ ← ∥W mm ∥F · µ; see discussions below (33). (2) We also include (L3.) in the statement of the theorem.
Theorem 8 Suppose Assumption A on the loss function ℓ and Assumption C on the initial gradient hold.
L1. For any µ > 0, there exists R > 0 such that C̄_{µ,R}(W^mm) defined in (33) does not contain any stationary points.
L2. Fix any µ ∈ (0, min(1, ι∥W^mm∥_F/∥∇L(0)∥_F)). Consider GD iterations with W(0) = 0, W(1) = −R∇L(0)/∥∇L(0)∥_F, and W(k + 1) = W(k) − η∇L(W(k)) for η ≤ 1/L_W, k ≥ 1, and R sufficiently large. If all iterates remain within C̄_{µ,R}, then lim_{k→∞} ∥W(k)∥_F = ∞ and lim_{k→∞} W(k)/∥W(k)∥_F = W^mm/∥W^mm∥_F.
L3. Assume η ≤ 1/L_W and, for all W ∈ C̄_{µ,R}(W^mm) with sufficiently large R,

    min_{i∈[n]} ⟨(x_{i opt_i} − x_it)z_i⊤, W − η∇L(W)⟩ ≥ min_{i∈[n]} ⟨(x_{i opt_i} − x_it)z_i⊤, W⟩ − (2ηµ/∥W^mm∥²_F) ⟨∇L(W), W^mm⟩,    (50)

then all GD iterations remain within C̄_{µ,R}(W^mm).
Proof. Note that L1. is a direct corollary of Lemma 11. We proceed with the proof of L2. and L3.. We provide the
proof in four steps:
Step 1: C̄µ,R0µ (W mm) construction. Let us denote the initialization lower bound as R0µ := R, where R is given in the
Theorem 8’s statement. Consider an arbitrary value of ϵ ∈ (0, µ/2) and let 1/(1 + π) = 1 − ϵ. We additionally denote
Rϵ ← Rπ ∨ 1/2 where Rπ was defined in Lemma 12. At initialization W(0), we set ϵ = µ/2 to obtain R0µ = Rµ/2 .
We proceed to show µ ∈ (0, min(1, ι∥W mm ∥F /∥∇L(0)∥F ). It follows from Assumption C and under zero initialization
for GD (W(0) = 0) that
D E D E
(xiopti − xit )z⊤i , −∇L(W(0)) = (xiopti − xit )z⊤i , −∇L(0) ≥ ι > 0,
for some positive constant ι. Hence, for any initial step size η(0) > 0 and W(1) = −η(0)∇L(0),
η(0) D
* +
W(1) E
(xiopti − xit )zi ,

= (xiopti − xit )z⊤i , −∇L(0)
∥W(1)∥F ∥W(1)∥F
ιη(0) ι
≥ = (51)
∥W(1)∥F ∥∇L(0)∥F
µ
≥ .
∥W mm ∥F
Here, the last inequality follows from our choice of µ in the theorem statement, i.e.
ι∥W mm ∥F
!!
µ ∈ 0, min 1, . (52)
∥∇L(0)∥F
This µ choice induces the conic set C̄µ,R0µ (W mm ) with R0µ = Rµ/2 , where Rµ/2 was defined in Lemma 12. Now, given the
parameter µ satisfying (52), we can choose η(0) such that ∥W(1)∥F ≥ R0µ and W(1) ∈ C̄µ,R0µ (W mm ). To achieve this, since
W(0) = 0, we obtain
R0µ
η(0) = . (53)
∥∇L(0)∥F
Since by our definition, R0µ ← R, (53) gives W(1) in the theorem’s statement.
Step 2: There are no stationary points within C̄µ,R0µ (W mm). This step follows from L1.. Specifically, we can apply
Lemma 11 to find that: For all V, W ∈ S̄µ (W mm ) with ∥W∥F , 0 and ∥W∥F ≥ R0µ , we have that − ⟨V, ∇L(W)⟩ is strictly
positive.
Gradient correlation holds for large parameter norm. It follows from Lemma 12 that, there exists Rϵ ≥ R̄µ ∨ 1/2 such
that all W ∈ C̄µ,Rϵ (W mm ) satisfy
W mm
* + * +
W
−∇L(W), ≥ (1 − ϵ) −∇L(W), . (54)
∥W mm ∥F ∥W∥F
The following argument applies to a general ϵ ∈ (0, µ/2). However, at initialization W(0) = 0, we have set ϵ = µ/2 and
defined the initialization radius as R0µ = Rµ/2 . To proceed, we will prove the main statements (L2.) and (L3.) as follows.

• Proving L3.: In Step 3, we will assume Condition (50) to prove that gradient iterates remain within C̄µ,Rϵ (W mm ).
Concretely, for any ϵ ∈ (0, µ/2), we will show that after gradient descent enters the conic set C̄µ,Rϵ (W mm ) for the
first time, it will never leave the set under Condition (50) of the theorem statement and (54). In what follows, let
us denote kϵ to be the first time gradient descent enters C̄µ,Rϵ (W mm ). Note that for ϵ ← µ/2, kϵ = 0 i.e. the point
of initialization.

• Proving L2.: In Step 4, assuming iterates within C̄µ,Rϵ (W mm ), we will prove that the norm diverges (as a result
such kϵ is guaranteed to exist) and, additionally, the gradient updates asymptotically aligns with W mm .
Step 3 (Proof of L3.): Updates remain inside the cone C̄µ,Rϵ (W mm). Note that if W(k) ∈ C̄µ,Rϵ (W mm ) for all k ≥ 1,
the required condition in L2. holds, and we proceed to Step 4. In this step, we show L3.. Specifically, we show that
under Condition (50) and using (54), all iterates W(k) ∈ C̄µ,Rϵ (W mm ) remain within C̄µ,Rϵ (W mm ).
To proceed, by leveraging the results from Step 1 and Step 2, we demonstrate that the gradient iterates, with an
appropriate constant step size, starting from W(kϵ ) ∈ C̄µ,Rϵ (W mm ), remain within this set. We proceed by induction.
Suppose that the claim holds up to iteration k ≥ kϵ . This implies that W(k) ∈ C̄µ,Rϵ (W mm ). Hence, recalling C̄µ,Rϵ (W mm )
defined in (33), there exists scalar µ = µ(α) ∈ (0, 1) and Rϵ such that ∥W(k)∥F ≥ Rϵ , and
* +
W(k)
(xiopti − xit )zi ,

≥ µΘ,
∥W(k)∥F

where Θ = 1/∥W mm ∥F .
Let
W mm
* +
1
, −∇L(W(k)) =: ρ(k) > 0. (55a)
1 − ϵ ∥W mm ∥F

Using (50), we have

⊤ W(k + 1) η
* + * +
W(k)
(xiopti − xit )zi , = (xiopti − xit )zi ,

− ∇L(W(k))
∥W(k)∥F ∥W(k)∥F ∥W(k)∥F
(56)
2η(1 − ϵ)µΘρ(k)
≥ µΘ + .
∥W(k)∥F
From Lemma 11, we have ⟨∇L(W(k)), W(k)⟩ < 0 which implies that ∥W(k + 1)∥F ≥ ∥W(k)∥F . This together with Rϵ
definition and ∥W(k)∥F ≥ 1/2 implies that
1  
∥W(k + 1)∥F ≤ ∥W(k + 1)∥2F + ∥W(k)∥2F
2∥W(k)∥F
1  
= 2∥W(k)∥2F − 2η ⟨∇L(W(k)), W(k)⟩ + η2 ∥∇L(W(k))∥2F
2∥W(k)∥F
η
≤ ∥W(k)∥F − ⟨∇L(W(k)), W(k)⟩ + η2 ∥∇L(W(k))∥2F .
∥W(k)∥F
Thus,

∥W(k + 1)∥F η ∥∇L(W(k))∥2F


* +
W(k)
≤1− ∇L(W(k)), + η2
∥W(k)∥F ∥W(k)∥F ∥W(k)∥F ∥W(k)∥F
η W mm ∥∇L(W(k))∥2F
* +
≤1− ∇L(W(k)), + η2 (57)
(1 − ϵ)∥W(k)∥F ∥W mm ∥F ∥W(k)∥F
ηρ(k) η ∥∇L(W(k))∥F
2 2
≤1+ + =: C1 (ρ(k), η).
∥W(k)∥F ∥W(k)∥F
Here, the second inequality uses (54).

Now, it follows from (56) and (57) that
W(k + 1) 2η(1 − ϵ)µΘρ(k)
* + !
1
min (xiopti − xit )zi ,

≥ µΘ +
t,opti , i∈[n] ∥W(k + 1)∥F C1 (ρ(k), η) ∥W(k)∥F
ηµΘ − ϵ) − 1 ρ(k) ∥∇L(W(k))∥2F 
  
2(1
= µΘ + η



C1 (ρ(k), η)

∥W(k)∥F ∥W(k)∥F

(58)
ηµΘ ∥∇L(W(k))∥2F 
 
 (1 − 2ϵ)ρ(k)
= µΘ + −η
C1 (ρ(k), η)
 
∥W(k)∥F ∥W(k)∥F
≥ µΘ,
where the last inequality uses our choice of stepsize η ≤ 1/LW in Theorem 4’s statement. Specifically, we need η
to be small to ensure the last inequality. We will guarantee this by choosing a proper Rϵ in Lemma 12. Specifically,
Lemma 12 leaves the choice of D0 in Rϵ lower bound of (45) open (it can always be chosen larger). Here, by choosing
D0 ≳ 1/LW will ensure η ≤ 1/LW works well.
 c Θ 1 R0µ Θ/2
η≤ 1−µ µ e
C Ā ĀCT
1 − 2ϵ cµ Θ 1 R0µ Θ/2
≤ e (59)
1 − ϵ C Ā ĀCT
ρ(k)
.

≤ 1 − 2ϵ
∥∇L(W(k))∥2F
Here, the first inequality follows since ϵ ∈ (0, µ/2) (as seen in Step 2). Also, µ < 1 implies that 1 − µ > 0, we obtain
η > 0. The last inequality is obtained from Lemma 11:
ρ(k) W mm cµ Θ
* +
1 ∇L(W(k)) 1
=− , ≥ · · ,
∥∇L(W(k))∥F 1 − ϵ ∥∇L(W(k))∥F ∥W ∥F mm 1 − ϵ C Ā
1 1 1
≥ ≥
ĀCT e−Rµ Θ/2
0

∥∇L(W(k))∥F ĀC · 1 Pn 1 − siopt
n i=1 i

for some data-dependent constants c, C, Ā = max_{i∈[n], t,τ∈[T]} ∥x_it − x_iτ∥ ∥z_i∥, and Θ = 1/∥W^mm∥_F.
The remainder of the proof of this step is identical to (86)–(90), with the replacement of C0 by D0 and the tracking
of changes. Specifically, Lemma 12 leaves the choice of D0 in Rϵ lower bound of (45) open (it can always be chosen
larger). Hence, for sufficiently large D0 , we have
1  c Θ 1 R0µ Θ/2
η≤ ≤ 1−µ µ e . (60)
LW C Ā ĀCT
This implies (58) and W(k + 1) ∈ C̄µ,Rϵ (W mm ).
Step 4 (Proof of L2.): W(k) and W mm perfectly align over time. By theorem statement (alternatively via Step 3),
we have that all iterates remain within the initial conic set i.e. W(k) ∈ C̄µ,R0µ (W mm ) for all k ≥ 0. Note that it follows
from Lemma 11 that ⟨∇L(W), W mm /∥W mm ∥F ⟩ < 0, for any finite W ∈ C̄µ,R0µ (W mm ). Hence, there are no finite critical
points W ∈ C̄µ,R0µ (W mm ), for which ∇L(W) = 0. Now, based on Lemma 9, which guarantees that ∇L(W(k)) → 0, this
implies that ∥W (k)∥ → ∞. Consequently, for any choice of ϵ ∈ (0, µ/2) there is an iteration kϵ such that, for all k ≥ kϵ ,
W(k) ∈ C̄µ,Rϵ (W mm ). Once within C̄µ,Rϵ (W mm ), multiplying both sides of (54) by the stepsize η and using the gradient
descent update, we get
W mm
* + * +
W(k)
W(k + 1) − W(k), ≥ (1 − ϵ) W(k + 1) − W(k),
∥W mm ∥F ∥W(k)∥F
(1 − ϵ)  
= ∥W(k + 1)∥2F − ∥W(k)∥2F − ∥W(k + 1) − W(k)∥2F
2∥W(k)∥F
!
1  
≥ (1 − ϵ) ∥W(k + 1)∥F − ∥W(k)∥F − ∥W(k + 1) − W(k)∥F
2 2 2
2∥W(k)∥F
 
≥ (1 − ϵ) ∥W(k + 1)∥F − ∥W(k)∥F − ∥W(k + 1) − W(k)∥2F
 
≥ (1 − ϵ) ∥W(k + 1)∥F − ∥W(k)∥F − 2η (L(W(k)) − L(W(k + 1))) .

Here, the second inequality is obtained from ∥W(k)∥F ≥ 1/2; the third inequality follows since for any a, b > 0, we
have (a2 − b2 )/(2b) − (a − b) ≥ 0; and the last inequality uses Lemma 9.
Summing the above inequality over k ≥ kϵ gives

W mm C(ϵ, η)
* +
W(k)
, ≥1−ϵ+ , W(k) ∈ C̄µ,Rϵ (W mm ),
∥W(k)∥F ∥W mm ∥F ∥W(k)∥F
where L⋆ ≤ L (W (k)) for all k ≥ kϵ , and
W mm
* +
C(ϵ, η) = W(kϵ ), − (1 − ϵ)∥W(kϵ )∥F − 2η(1 − ϵ)(L(W(kϵ )) − L⋆ ).
∥W mm ∥F
Consequently,
W mm
* +
W(k)
lim inf , ≥ 1 − ϵ, W(k) ∈ C̄µ,Rϵ (W mm ).
k→∞ ∥W(k)∥F ∥W mm ∥F
Since ϵ ∈ (0, µ/2) is arbitrary, this implies W(k)/∥W(k)∥F → W mm /∥W mm ∥F .

D Local Convergence of Gradient Descent


To provide a basis for discussing local convergence of GD, we establish a cone centered around Wαmm using the following
construction. For parameters µ ∈ (0, 1) and R > 0, we define Cµ,R (Wαmm ) as the set of matrices W ∈ Rd×d such that
∥W∥F ≥ R and the correlation coefficient between W and Wαmm is at least 1 − µ:
    S_µ(W_α^mm) := { W ∈ R^{d×d} : ⟨W/∥W∥_F, W_α^mm/∥W_α^mm∥_F⟩ ≥ 1 − µ },    (61a)
    C_{µ,R}(W_α^mm) := S_µ(W_α^mm) ∩ { W ∈ R^{d×d} : ∥W∥_F ≥ R }.    (61b)
Lemma 13 Suppose Assumption A on the loss function ℓ holds, and let α = (αi )ni=1 be locally optimal tokens according
to Definition 2. Let W mm = Wαmm denote the SVM solution obtained via (Att-SVM) by applying the Frobenius norm
and replacing (opti )ni=1 with α = (αi )ni=1 . There exists a scalar µ = µ(α) > 0 such that for sufficiently large R̄µ :
L1. There is no stationary point within Cµ,R̄µ (W mm ).
L2. For all V ∈ S_µ(W^mm) with ∥V∥_F = ∥W^mm∥_F and all W ∈ C_{µ,R̄_µ}(W^mm), there exist dataset-dependent constants C, c > 0 such that

    C · (1/n) ∑_{i=1}^n (1 − s_{iα_i}) ≥ −⟨∇L(W), V⟩ ≥ c · (1/n) ∑_{i=1}^n (1 − s_{iα_i}) > 0,    (62a)
    ∥∇L(W)∥_F ≤ Ā C · (1/n) ∑_{i=1}^n (1 − s_{iα_i}),    (62b)
    −⟨ V/∥V∥_F, ∇L(W)/∥∇L(W)∥_F ⟩ ≥ (c/C) · (Θ/Ā) > 0.    (62c)

Here, s_{iα_i} = (S(X_i W z_i))_{α_i}, Ā = max_{i∈[n], t,τ∈[T]} ∥x_it − x_iτ∥ ∥z_i∥, and Θ = 1/∥W^mm∥_F.
Proof. Let R = R̄_µ and let (T_i)_{i=1}^n be the sets of support indices per Definition 2. Let T̄_i = [T] − T_i − {α_i} be the non-support indices. Let

    Θ = 1/∥W^mm∥_F,
    δ = (1/2) min_{i∈[n]} min_{t∈T_i, τ∈T̄_i} (x_it − x_iτ)⊤ W^mm z_i,
    A = max_{i∈[n], t∈[T]} ∥x_it z_i⊤∥_F / Θ,    (63)
    µ ≤ µ(δ) = (1/8) ( min(0.5, δ)/A )².

Since W^mm is the max-margin model ensuring (x_{iα_i} − x_it)⊤ W^mm z_i ≥ 1, the following inequalities hold for all W ∈ S_µ(W^mm) with ∥W∥_F = ∥W^mm∥_F and all i ∈ [n], t ∈ T_i, τ ∈ T̄_i:

    (x_it − x_iτ)⊤ W z_i ≥ δ > 0,
    (x_{iα_i} − x_iτ)⊤ W z_i ≥ 1 + δ,    (64)
    3/2 ≥ (x_{iα_i} − x_it)⊤ W z_i ≥ 1/2.

Here, we used ∥W − W^mm∥²_F / ∥W^mm∥²_F ≤ 2µ, which implies ∥W − W^mm∥_F ≤ √(2µ)/Θ.
To proceed, we write the gradient correlation following (15) and (27)
n
1X ′ ⊤ ′
⟨∇L(W), V⟩ = ℓ · h S ( h̃i )γi , (65)
n i=1 i i

where we denoted ℓi′ = ℓ′ (Yi · v⊤ Xi⊤ S( h̃i )), hi = Xi V zi , h̃i = XiW zi , and si = S( h̃i ).
Using (64), for all t ∈ Ti , τ ∈ T̄i , for all W ∈ Cµ,R (W mm ), we have that

h̃it − h̃iτ ≥ RΘδ,


h̃iαi − h̃iτ ≥ RΘ(1 + δ),
h̃iαi − h̃it ≥ RΘ/2.

Consequently, we can bound the softmax probabilities si = S( h̃i ) over non-support indices as follows: For all i ∈ [n]
and any ti ∈ Ti
X
S i := siτ ≤ T e−RΘ/2 siαi ≤ T e−RΘ/2 , (66a)
τ∈Ti
X
Qi := siτ ≤ T e−RΘδ siti ≤ T e−RΘδ S i . (66b)
τ∈T̄i

Recall scores γit = Yi · v⊤ xit . Define the score gaps over support indices:
gap gap
γi = γiαi − max γit and γ̄i = γiαi − min γit .
t∈Ti t∈Ti

It follows from (63) that

∥xit z⊤i ∥F
A = max ≥ max ∥hit ∥.
i∈[n],t∈[T ] Θ i∈[n],t∈[T ]

Define the α-dependent global scalar Γ = supi∈[n],t,τ∈[T ] |γit − γiτ |.


Let us focus on a fixed datapoint i ∈ [n], assume (without losing generality) αi = 1, and drop subscripts i. Directly
applying Lemma 8, we obtain
T
X
h⊤ diag(s)γ − h⊤ ss⊤ γ − (h1 − ht )st (γ1 − γt ) ≤ 2ΓA(1 − s1 )2 .
t≥2

To proceed, let us decouple the non-support indices within Tt≥2 (h1 − ht )st (γ1 − γt ) via
P

X
(h1 − ht )st (γ1 − γt ) ≤ 2QΓA.
t∈T̄

Aggregating these, we found


X
h⊤ diag(s)γ − h⊤ ss⊤ γ − (h1 − ht )st (γ1 − γt ) ≤ 2ΓA((1 − s1 )2 + Q). (67)
t∈T

To proceed, let us upper/lower bound the gradient correlation. We use two bounds depending on V ∈ Sµ (W mm ) (Case
1) or general V ∈ Rd×d (Case 2).
• Case 1: V ∈ Sµ (W mm). Since 1.5 ≥ h1 − ht ≥ 0.5 following (64), we find
X
1.5 · S · γ̄gap ≥ (h1 − ht )st (γ1 − γt ) ≥ 0.5 · S · γgap ,
t∈T

where recall the definition of S (having dropped subscripts) in (66a).


• Case 2: V ∈ Rd×d and ∥V∥F = ∥W mm∥F . Define Ā = maxi∈[n],t,τ∈[T ] ∥xit − xiτ ∥ ∥zi ∥. For any ∥V∥F = ∥W mm ∥, we use
the fact that

∥h1 − ht ∥ ≤ ∥(xit − xiτ )z⊤i ∥F · ∥V∥F ≤ .
Θ

Note that by definition Θ ≥ 1. To proceed, we can upper bound

Ā X
· S · γ̄gap ≥ (h1 − ht )st (γ1 − γt ). (68)
Θ t∈T

Next we claim that for both cases, S dominates ((1 − s1 )2 + Q) for large R. Specifically, we wish for
S · γgap ΓA
≥ 4ΓA max((1 − s1 )2 , Q) ⇐⇒ S ≥ 16 gap max((1 − s1 )2 , Q). (69)
4 γ

Now choose R ≥ δ−1 log(T )/Θ to ensure Q ≤ S since Q ≤ T e−RΘδ S from (66a). Consequently

(1 − s1 )2 = (Q + S )2 ≤ 4S 2 ≤ 4S T e−RΘ/2 .

Combining these, what we wish is ensured by guaranteeing


ΓA
S ≥ 16 max(4S T e−RΘ/2 , T e−RΘδ S ). (70)
γgap
This in turn is ensured for all inputs i ∈ [n] by choosing

max(2, δ−1 )  64T ΓA 


 
R≥ log  gap  , (71)
Θ γmin
gap gap
where γmin = mini∈[n] γi is the global scalar which is the worst case score gap over all inputs.
• Case 1: V ∈ Sµ (W mm). With the above choice of R, we guaranteed

S · γgap (1 − s1 )γgap
2(1 − s1 ) · γ̄gap ≥ 2 · S · γ̄gap ≥ h⊤ diag(s)γ − h⊤ ss⊤ γ ≥ ≥ .
4 8
via (69) and (67).
Since this holds over all inputs, going back to the gradient correlation (65) and averaging above over all inputs
i ∈ [n] and plugging back the indices i, we obtain the advertised bound
2X ′ gap 1 X ′ gap
−ℓi · S i · γ̄i ≥ − ⟨∇L(W), V⟩ ≥ −ℓ · S i · γi . (72)
n i∈[n] 8n i∈[n] i

gap
/ max be the min/max values negative loss derivative admits over the ball [−A, A] and note that maxi∈[n] γ̄i >0

Let −ℓmin
gap gap
and mini∈[n] γi > 0 are dataset dependent constants. Then, we declare the constants C = −2ℓmax · maxi∈[n] γ̄i > 0, c =

gap

−(1/8)ℓmin · mini∈[n] γi > 0 to obtain the bound (62a).
• Case 2: V ∈ Rd×d and ∥V∥F = ∥W mm∥F . Next, we show (62b) and (62c). For any V ∈ Rd×d satisfying ∥V∥F =
∥W mm ∥F , using (68) and the choice of R in (71) similarly guarantees

2Ā
(1 − s1 )γ̄gap ≥ h⊤ diag(s)γ − h⊤ ss⊤ γ,
Θ

for fixed input. Going back to the gradient correlation (65) and averaging above over all inputs i ∈ [n], with the same
definition of C > 0, we obtain

ĀC X
(1 − siαi ) ≥ − ⟨∇L(W), V⟩ . (73)
Θn i∈[n]

∥W mm ∥F
To proceed, since (73) holds for any V ∈ Rd×d , we observe that when setting V = ∥∇L(W)∥F · ∇L(W), this implies that

ĀC X
⟨∇L(W), V⟩ = ∥∇L(W)∥F · ∥W mm ∥F ≤ (1 − siαi ).
Θn i∈[n]

Simplifying Θ = 1/∥W mm ∥F on both sides gives (62b).


Combining the above inequality with (72), we obtain that for all V, W ∈ Sµ (W mm )
* +
V ∇L(W) cΘ
− , ≥ ,
∥V∥F ∥∇L(W)∥F C Ā
which gives (62c).

Lemma 14 Suppose Assumption A on the loss function ℓ holds, and let α = (α_i)_{i=1}^n be locally optimal tokens according to Definition 2. Let W^mm = W_α^mm denote the SVM solution obtained via (Att-SVM) by replacing (opt_i)_{i=1}^n with α = (α_i)_{i=1}^n. Let µ = µ(α) > 0 and R̄_µ be defined as in Lemma 13. For any choice of π > 0, there exists R_π ≥ R̄_µ such that, for any W ∈ C_{µ,R_π}(W^mm), we have

    ⟨∇L(W), W/∥W∥_F⟩ ≥ (1 + π) ⟨∇L(W), W^mm/∥W^mm∥_F⟩.

Proof. Let R = Rπ , W̄ = ∥W mm ∥F W/∥W∥F , hi = XiW̄ zi , and h̄i = XiW mm zi . To establish the result, we will prove that,
for sufficiently large R and for any W ∈ Cµ,R (W mm ):
* + n
W 1X ′
−∇L(W), =− ℓ · hi , S′ (XiW zi )γi
∥W∥F n i=1 i
n
1+πX ′ D W mm
E * +
≤− ℓ · h̄i , S (XiW zi )γi = (1 + π) −∇L(W),

. (74)
n i=1 i ∥W mm ∥F

Following (67), for all W ∈ Sµ (W mm ) with ∥W∥F = ∥W mm ∥F , h̃ = XW z, and s = S( h̃), we have found
X
h̃⊤i diag(si )γi − h̃⊤i si s⊤i γi − ( h̃i1 − h̃it )sit (γi1 − γit ) ≤ 2ΓA((1 − si1 )2 + Qi ), (75)
t∈Ti

where Ti is the set of support indices.


Plugging in h, h̄ in the bound above and assuming π ≤ 1 (w.l.o.g.), (74) is implied by the following stronger
inequality
n
 
1 X ′  X 
− ℓi · 6ΓA((1 − si1 ) + Qi ) +
2
(hi1 − hit )sit (γi1 − γit )
n i=1 t∈T i
n
1+πX X
≤− ℓi′ · ( h̄i1 − h̄it )sit (γi1 − γit )
n i=1 t∈Ti
n
1+π X X
=− ℓi′ · sit (γi1 − γit ).
n i=1 t∈Ti

First, we claim that 0.5π t∈Ti sit (γi1 − γit ) ≥ 6ΓA((1 − si1 )2 + Qi ) for all i ∈ [n]. The proof of this claim directly follows
P
the earlier argument, namely, following (69), (70), and (71) which leads to the choice

max(2, δ−1 )  C0 · T ΓA 


 
R≥ log  , (76)
Θ gap 
πγmin

for some constant C0 > 0. Using (71), we choose C0 ≥ 64π to guarantee R = Rπ ≥ R̄µ .
Following this control over the perturbation term 6ΓA((1 − si1 )2 + Qi ), to conclude with the result, what remains is
proving the comparison
n n
1X ′ X 1 + 0.5π X ′ X
− ℓi · (hi1 − hit )sit (γi1 − γit ) ≤ − ℓi · sit (γi1 − γit ). (77)
n i=1 t∈T
n i=1 t∈T
i i

To proceed, we split the problem into two scenarios.


π
Scenario 1: ∥W̄ − W mm ∥F ≤ ϵ = 4AΘ for some ϵ > 0. In this scenario, for any t ∈ Ti and i ∈ [n], we have
π
|hit − h̄it | = |x⊤it (W̄ − W mm )zit | ≤ AΘϵ = .
4
Consequently, we obtain
hi1 − hit ≤ h̄i1 − h̄it + 2AΘϵ = 1 + 0.5π.
Similarly, hi1 − hit ≥ 1 − 0.5π ≥ 0.5. Since all terms hi1 − hit , sit , γi1 − γit in (77) are nonnegative and (hi1 − hit )sit (γi1 −
γit ) ≤ (1 + 0.5π)sit (γi1 − γit ), above implies the desired result in (77).
π
Scenario 2: ∥W̄ − W mm ∥F ≥ ϵ = 4AΘ . Since W̄ is not (locally) max-margin, in this scenario, for some i ∈ [n],
ν = ν(ϵ) > 0, and τ ∈ Ti , we have that

hi1 − hiτ ≤ 1 − 2ν.

Here τ = arg maxτ∈Ti xiτW̄ zi denotes the nearest point to hi1 (along the W̄ direction). Note that a non-neighbor t ∈ T̄i
cannot be nearest because W̄ ∈ Sµ (W mm ) and (64) holds. Recall that si = S(R̄hi ) where R̄ = ∥W∥F Θ ≥ RΘ. To proceed,
let hi := mint∈Ti hi1 − hit ,
n o n o
I := i ∈ [n] : hi ≤ 1 − 2ν , [n] − I := i ∈ [n] : 1 − 2ν < hi .

For all i ∈ [n] − I,


X X
(hi1 − hit )sit (γi1 − γit ) − (1 + 0.5π) sit (γi1 − γit )
t∈Ti t∈Ti
X
≤ (2A − (1 + 0.5π)) Γ sit
t∈Ti , hi1 −hit ≥1+ π2 (78)
−R̄(1+ π2 )
≤ (2A − (1 + 0.5π)) ΓT e
π
≤ 2AΓT e−R̄(1+ 2 ) .

For all i ∈ I, split the tokens into two groups: Let Ni be the group of tokens obeying hi1 − hit ≤ 1 − ν and Ti − Ni
be the rest of the neighbors. Observe that

eνR̄
P
t∈T −N sit
P i i ≤T = T e−R̄ν .
t∈Ti sit e2νR̄
gap gap
Using |hi1 − hit | ≤ 2A = 2 maxi∈[n],t∈[T ] ∥kit ∥/Θ and γmin = mini∈[n] γi = mini∈[n] (γi1 − maxt∈Ti γit ), observe that

X 2ΓAT e−R̄ν X
(hi1 − hit )sit (γi1 − γit ) ≤ gap sit (γi1 − γit ).
t∈Ti −Ni
γmin t∈Ti

Thus,
X X X
(hi1 − hit )sit (γi1 − γit ) = (hi1 − hit )sit (γi1 − γit ) + (hi1 − hit )sit (γi1 − γit )
t∈Ti t∈Ni t∈Ti −Ni
X 2ΓAT e−R̄ν X
≤ (1 − ν)sit (γi1 − γit ) + gap sit (γi1 − γit )
t∈Ni
γmin t∈Ti
 
2ΓAT e−R̄ν  X
≤ 1 − ν + sit (γi1 − γit )

gap
γmin

t∈Ti
 
2ΓAT e−R̄ν  X
≤ 1 + sit (γi1 − γit ).

gap
γmin

t∈Ti

Hence, choosing
 
1  8ΓAT 
R≥ log  gap  (79)
νΘ γmin π

results in that
X π X
(hi1 − hit )sit (γi1 − γit ) − (1 + ) sit (γi1 − γit )
t∈Ti
2 t∈T
i

 2ΓAT e−R̄ν π  X


 
≤  gap −  sit (γi1 − γit )
γmin 2 t∈T
i (80)
πX
≤− sit (γi1 − γit )
4 t∈T
i
π gap
≤ − γmin e−R̄(1−2ν) .
4T
−R̄(1−2ν)
Here, the last inequality follows from the fact that t∈Ti sit ≥ maxt∈Ti sit ≥ PT e e−R̄(hi1 −hit ) ≥ e−R̄(1−2ν) /T .
P
t=1
From Assumption A, we have cmin ≤ −ℓ′ ≤ cmax for some positive constants cmin and cmax . It follows from (78) and
(80) that
n
 
1 X ′ X X 
− ℓi ·  (hi1 −it )sit (γi1 − γit ) − (1 + 0.5π)sit (γi1 − γit )
n i t∈T t∈T
i i
gap
−R̄(1+ π2 ) cmin πγmin −R̄(1−2ν)
≤ cmax 2AΓT Γe − · e
nT 4
≤ 0.

Combining with (79), this is guaranteed by choosing

    R ≥ max{ (1/(νΘ)) log( 8ΓAT/(γ^gap_min π) ),  (1/((2ν + π/2)Θ)) log( 8nΓAT² c_max/(c_min γ^gap_min π) ) },

where ν = ν(π/(4AΘ)) depends only on π and global problem variables.
Combining this with the prior choice of R in (76) (by taking the maximum), we conclude with the statement.

D.1 Proof of Theorem 5


The proof of this theorem follows the proof of [TLZO23, Theorem 3]. Let us denote the initialization lower bound
as R0µ := R, where R is given in the Theorem 5’s statement. Consider an arbitrary value of ϵ ∈ (0, µ/2) and let
1/(1 + π) = 1 − ϵ. We additionally denote Rϵ ← Rπ ∨ 1/2 where Rπ was defined in Lemma 14. At initialization W(0),

we set ϵ = µ/2 to obtain R0µ = Rµ/2 , and provide the proof in four steps:
Step 1: There are no stationary points within Cµ,R0µ (W mm). We begin by proving that there are no stationary points
within Cµ,R0µ (W mm ). Let (Ti )ni=1 denote the sets of support indices as defined in Definition 2. We define T̄i = [T ]−Ti −{αi }
as the tokens that are non-support indices. Additionally, let µ be defined as in (63). Then, since R0µ ≥ R̄µ per Lemma
14, we can apply Lemma 13 to find that: For all V, W ∈ Sµ (W mm ) with ∥W∥F , 0 and ∥W∥F ≥ R0µ , we have that
− ⟨V, ∇L(W)⟩ is strictly positive.
Step 2: It follows from Lemma 14 that, there exists Rϵ ≥ R̄µ ∨ 1/2 such that all W ∈ Cµ,Rϵ (W mm ) satisfy
W mm
* + * +
W
−∇L(W), ≥ (1 − ϵ) −∇L(W), . (81)
∥W mm ∥F ∥W∥F
The argument below applies to a general ϵ ∈ (0, µ/2). However, at initialization W(0), we set ϵ = µ/2 and, recalling
above, initialization lower bound was defined as R0µ := Rµ/2 . To proceed, for any ϵ ∈ (0, µ/2), we will show that after
gradient descent enters the conic set Cµ,Rϵ (W mm ) for the first time, it will never leave the set. Let tϵ be the first time
gradient descent enters Cµ,Rϵ (W mm ). In Step 4, we will prove that such tϵ is guaranteed to exist. Additionally, for
ϵ ← µ/2, note that tϵ = 0 i.e. the point of initialization.
Step 3: Updates remain inside the cone Cµ,Rϵ (W mm). By leveraging the results from Step 1 and Step 2, we demonstrate
that the gradient iterates, with an appropriate constant step size, starting from W(kϵ ) ∈ Cµ,Rϵ (W mm ), remain within this
cone.
We proceed by induction. Suppose that the claim holds up to iteration k ≥ kϵ . This implies that W(k) ∈ Cµ,Rϵ (W mm ).
Hence, recalling cone definition, there exists scalar µ = µ(α) ∈ (0, 1) and R such that ∥W(k)∥F ≥ R, and
W mm
* +
W(k)
, ≥ 1 − µ.
∥W(k)∥F ∥W mm ∥F
For all k ≥ 1, let
W mm
* +
1
ρ(k) := − ∇L(W(k)), . (82)
1−ϵ ∥W mm ∥F
Note that ρ(k) > 0 due to Step 1. This together with the gradient descent update rule gives
W(k + 1) W mm η W mm
* + * +
W(k)
, = − ∇L(W(k)),
∥W(k)∥F ∥W mm ∥F ∥W(k)∥F ∥W(k)∥F ∥W mm ∥F
η W mm
* +
≥1−µ− ∇L(W(k)), (83a)
∥W(k)∥F ∥W mm ∥F
ηρ(k)(1 − ϵ)
≥1−µ+ .
∥W(k)∥F
Note that from Lemma 13, we have ⟨∇L(W(k)), W(k)⟩ < 0 which implies that ∥W(k + 1)∥F ≥ ∥W(k)∥F . This together
with Rϵ definition and ∥W(k)∥F ≥ 1/2 implies that
1  
∥W(k + 1)∥F ≤ ∥W(k + 1)∥2F + ∥W(k)∥2F
2∥W(k)∥F
1  
= 2∥W(k)∥2F − 2η ⟨∇L(W(k)), W(k)⟩ + η2 ∥∇L(W(k))∥2F
2∥W(k)∥F
η
≤ ∥W(k)∥F − ⟨∇L(W(k)), W(k)⟩ + η2 ∥∇L(W(k))∥2F ,
∥W(k)∥F
which gives
∥W(k + 1)∥F η ∥∇L(W(k))∥2
* +
W(k)
≤1− ∇L(W(k)), + η2
∥W(k)∥F ∥W(k)∥F ∥W(k)∥F ∥W(k)∥F
η mm
∥∇L(W(k))∥2
* +
W
≤1− ∇L(W(k)), + η2 (83b)
(1 − ϵ)∥W(k)∥F mm
∥W ∥F ∥W(k)∥F
ηρ(k) η ∥∇L(W(k))∥
2 2
≤1+ + =: C1 (ρ(k), η).
∥W(k)∥F ∥W(k)∥F

Here, the second inequality follows from (81) and (82).
Now, it follows from (83a) and (83b) that

W(k + 1) W mm ηρ(k)(1 − ϵ)
* + !
1
, ≥ 1−µ+
∥W(k + 1)∥ ∥W mm ∥ C1 (ρ(k), η) ∥W(k)∥F
ηρ(k)(1 − ϵ)
!
1
=1−µ+ (1 − µ)(1 − C1 (ρ(k), η)) +
C1 (ρ(k), η) ∥W(k)∥F
η ρ(k) η∥∇L(W(k))∥2 ρ(k)(1 − ϵ)
!
(84)
=1−µ+ (µ − 1)( + )+
C1 (ρ(k), η) ∥W(k)∥F ∥W(k)∥F ∥W(k)∥F
η ρ(k)(µ − ϵ) ∥∇L(W(k))∥2
!
=1−µ+ − η(1 − µ)
C1 (ρ(k), η) ∥W(k)∥F ∥W(k)∥F
≥ 1 − µ,

where the last inequality uses our choice of stepsize η ≤ 1/LW in Theorem 5’s statement. Specifically, we need η
to be small to ensure the last inequality. We will guarantee this by choosing a proper Rϵ in Lemma 14. Specifically,
Lemma 14 leaves the choice of C0 in Rϵ lower bound of (76) open (it can always be chosen larger). Here, by choosing
C0 ≳ 1/LW will ensure η ≤ 1/LW works well.
    η ≤ ( µ / (2(1 − µ)(1 − µ/2)) ) · (cΘ/(CĀ)) · (1/(ĀCT)) e^{R⁰_µΘ/2}
      ≤ ( (µ − ϵ)/(1 − µ) ) · (1/(1 − ϵ)) · (cΘ/(CĀ)) · (1/(ĀCT)) e^{R⁰_µΘ/2}
      ≤ ( (µ − ϵ)/(1 − µ) ) · ρ(k)/∥∇L(W(k))∥²_F.    (85)

Here, the first inequality uses our choice of ϵ ∈ (0, µ/2) (see Step 2), and the last inequality is obtained from Lemma 13 since

    ρ(k)/∥∇L(W(k))∥_F = −(1/(1 − ϵ)) ⟨∇L(W(k))/∥∇L(W(k))∥_F, W^mm/∥W^mm∥_F⟩ ≥ (1/(1 − ϵ)) · (c/C) · (Θ/Ā),
    1/∥∇L(W(k))∥_F ≥ 1/( ĀC · (1/n) ∑_{i=1}^n (1 − s_{iα_i}) ) ≥ 1/( ĀCT e^{−R⁰_µΘ/2} ),

for some data-dependent constants c and C, Ā = max_{i∈[n], t,τ∈[T]} ∥x_it − x_iτ∥ ∥z_i∥, and Θ = 1/∥W^mm∥_F.
Next, we will demonstrate that the choice of η in (85) does indeed meet our step size condition as stated in the
theorem, i.e., η ≤ 1/LW . Recall that 1/(1 + π) = 1 − ϵ, which implies that π = ϵ/(1 − ϵ). Combining this with (76), we
obtain:
max(2, δ−1 )  C0 T ΓA 
 
Rπ ≥ log   , where C0 ≥ 64π. (86)
Θ gap 
πγmin
max(2, δ−1 )  (1 − ϵ)C0 T ΓA  ϵ
 
⇒ Rϵ ≥ log   , where C0 ≥ 64 . (87)
Θ gap
ϵγmin 1−ϵ

On the other hand, at the initialization, we have ϵ = µ/2 which implies that

max(2, δ−1 )  (2 − µ)C0 T ΓA  µ


 
0
Rµ ≥ log   , where C0 ≥ 64 . (88)
Θ gap
µγmin 2(1 − µ2 )

In the following, we will determine a lower bound on C0 such that our step size condition in Theorem 5’s statement, i.e.,
η ≤ 1/LW , is satisfied. Note that for the choice of η in (85) to meet the condition η ≤ 1/LW , the following condition
must hold:
µ 1 2−µ
!
1 1 R0µ Θ/2 2
≤ e 0
⇒ Rµ ≥ log C2 T . (89)
LW (2 − µ) C2 T Θ LW µ
2 2
where C2 = (1 − µ) ĀΘcC .

This together with (88) implies that
gap
C0 ΓA  (1 − µ)C2 γmin 64µ 
 
C2
≥ (1 − µ) ⇒ C 0 ≥ max ,  . (90)
gap
γmin ΓA 2 − µ

LW LW

Therefore, with this lower bound on C0 , the step size bound in (85) is sufficiently large to ensure that η ≤ 1/LW
guarantees (84).
Hence, it follows from (84) that W(k + 1) ∈ Cµ,Rϵ (W mm ).
Step 4: The correlation of W(k) and W mm increases over k. From Step 3, we have that all iterates remain
within the initial conic set i.e. W(k) ∈ Cµ,R0µ (W mm ) for all k ≥ 0. Note that it follows from Lemma 13 that
⟨∇L(W), W mm /∥W mm ∥F ⟩ < 0, for any finite W ∈ Cµ,R0µ (W mm ). Hence, there are no finite critical points W ∈ Cµ,R0µ (W mm ),
for which ∇L(W) = 0. Now, based on Lemma 9, which guarantees that ∇L(W(k)) → 0, this implies that ∥W (t)∥F → ∞.
Consequently, for any choice of ϵ ∈ (0, µ/2) there is an iteration kϵ such that, for all k ≥ kϵ , W(k) ∈ Cµ,Rϵ (W mm ). Once
within Cµ,Rϵ (W mm ), multiplying both sides (81) by the stepsize η and using the gradient descent update, we get

W mm
* + * +
W(k)
W(k + 1) − W(k), ≥ (1 − ϵ) W(k + 1) − W(k),
∥W mm ∥F ∥W(k)∥F
(1 − ϵ)  
= ∥W(k + 1)∥2F − ∥W(k)∥2F − ∥W(k + 1) − W(k)∥2F
2∥W(k)∥F
!
1  
≥ (1 − ϵ) ∥W(k + 1)∥2F − ∥W(k)∥2F − ∥W(k + 1) − W(k)∥2F
2∥W(k)∥F
 
≥ (1 − ϵ) ∥W(k + 1)∥F − ∥W(k)∥F − ∥W(k + 1) − W(k)∥2F
 
≥ (1 − ϵ) ∥W(k + 1)∥F − ∥W(k)∥F − 2η (L(W(k)) − L(W(k + 1))) .

Here, the second inequality is obtained from ∥W(k)∥F ≥ 1/2; the third inequality follows since for any a, b > 0, we
have (a2 − b2 )/(2b) − (a − b) ≥ 0; and the last inequality uses Lemma 9.
Summing the above inequality over k ≥ kϵ gives

W mm C(ϵ, η)
* +
W(k)
, ≥1−ϵ+ , W(k) ∈ Cµ,Rϵ (W mm ),
∥W(k)∥F ∥W mm ∥F ∥W(k)∥F

where L⋆ ≤ L (W (k)) for all k ≥ 0, and

W mm
* +
C(ϵ, η) = W(kϵ ), − (1 − ϵ)∥W(kϵ )∥F − 2η(1 − ϵ)(L(W(kϵ )) − L⋆ ).
∥W mm ∥F

Consequently, as k → ∞

W mm
* +
W(k)
lim inf , ≥ 1 − ϵ, W(k) ∈ Cµ,Rϵ (W mm ).
k→∞ ∥W(k)∥F ∥W mm ∥F

Since ϵ ∈ (0, µ/2) is arbitrary, we get W(k)/∥W(k)∥F → W mm /∥W mm ∥F . ■

E Convergence of Regularization Path for Sequence-to-Sequence Setting


In this section, we provide proofs for the regularization path analysis. We first provide a more general formulation of the
optimization problem that allows for regressing multiple token outputs. To distinguish from (W-ERM)&(KQ-ERM),
let us call this more general version Sequence Empirical Risk Minimization (SERM).
Problem definition: Rather than a single input sequence X, let us allow for two input sequences X ∈ R^{T×d} and Z ∈ R^{K×d} with tokens (x_t)_{t=1}^T and (z_k)_{k=1}^K. The cross-attention admits X, Z and outputs K tokens. We will also allow for K separate prediction heads (h_k)_{k=1}^K for the individual cross-attention outputs, which strictly generalizes the setting where we used a single prediction head h(·). Denote the training labels associated with the output tokens of sample i as Y_i = (Y_ik)_{k=1}^K. Given n samples
(Y_i, X_i, Z_i)_{i=1}^n, for a decreasing loss function ℓ(·), minimize the empirical risk over the attention outputs in either a univariate (W ∈ R^{d×d}) or bivariate (K, Q ∈ R^{d×m}) fashion:

    L(W) = (1/n) ∑_{i=1}^n ∑_{k=1}^K ℓ( Y_ik · h_k( X_i⊤ S(X_i W z_ik) ) ),    (SERM-W)
    L(K, Q) = (1/n) ∑_{i=1}^n ∑_{k=1}^K ℓ( Y_ik · h_k( X_i⊤ S(X_i K Q⊤ z_ik) ) ).    (SERM-KQ)
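A minimal numerical sketch of the objective (SERM-W), assuming generic numpy inputs, a logistic loss, and caller-supplied prediction heads; all of these are illustrative choices rather than requirements of the formulation.

```python
# Sketch of the sequence empirical risk (SERM-W).
import numpy as np

softmax = lambda u: np.exp(u - u.max()) / np.exp(u - u.max()).sum()

def serm_w_loss(W, Xs, Zs, Y, heads, loss=lambda m: np.log(1 + np.exp(-m))):
    """Xs[i]: (T, d) tokens, Zs[i]: (K, d) query tokens, Y[i][k]: labels, heads[k]: callable R^d -> R."""
    n, total = len(Xs), 0.0
    for i in range(n):
        for k in range(Zs[i].shape[0]):
            s = softmax(Xs[i] @ W @ Zs[i][k])     # attention weights for the k-th output
            a = Xs[i].T @ s                       # cross-attention output X_i^T S(X_i W z_ik)
            total += loss(Y[i][k] * heads[k](a))
    return total / n
```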

In order to recover the single-output self-attention model, we can simply set K = 1, z_i1 ← x_i1, and h_k ← h. To proceed, we introduce the more general version of the (Att-SVM) problem, which we refer to as Sequential Cross-Attention SVM to preserve consistent phrasing. Suppose K, Q ∈ R^{d×m} with m ≤ d and let R_m denote the set of rank-m matrices in R^{d×d}. Given indices α = (α_ik)_{ik=(1,1)}^{(n,K)}, consider the SVM with ⋄-norm constraint

    W^mm_{⋄,α} ∈ arg min_{W∈R_m} ∥W∥_⋄  s.t.  min_{t≠α_ik} (x_{iα_ik} − x_it)⊤ W z_ik ≥ 1  ∀ i ∈ [n], k ∈ [K].    (SAtt-SVM)

When the solution is non-unique, we denote the solution set by W^mm(α). In what follows, we denote F_ikt := x_it z_ik⊤ and, given α, we write x_{α_ik} = x_{iα_ik} and F_{α_ik} := x_{iα_ik} z_ik⊤. With this notation, we can equivalently write

    W^mm_{⋄,α} ∈ arg min_{W∈R_m} ∥W∥_⋄  s.t.  min_{t≠α_ik} ⟨F_{α_ik} − F_ikt, W⟩ ≥ 1  ∀ i ∈ [n], k ∈ [K].    (SAtt-SVM')

Definition 4 (Support Indices and Locally-Optimal Indices) Fix token indices α = (α_ik)_{ik=(1,1)}^{(n,K)} for which (SAtt-SVM) is feasible, and obtain W_α^mm := W^mm_{⋄,α}. Define the token scores as

    γ_ikt = Y_ik · h_k(x_it),   γ_{α_ik} := γ_{ikα_ik} = Y_ik · h_k(x_{α_ik}).

Consider the tokens T_ik ⊂ [T] such that ⟨F_{α_ik} − F_ikt, W_α^mm⟩ = 1 for all t ∈ T_ik; T_ik is allowed to be an empty set. We refer to T_ik as the support indices of F_{α_ik} = x_{α_ik} z_ik⊤ and define its complement T̄_ik = [T] − T_ik − {α_ik}. Additionally, token indices α = (α_ik)_{ik=(1,1)}^{(n,K)} are called locally-optimal if for all i ∈ [n], k ∈ [K] and t ∈ T_ik, the token scores obey γ_{α_ik} > γ_ikt. The associated W_α^mm is called a locally-optimal direction. Finally, let opt_ik ∈ arg max_{t∈[T]} γ_ikt be the optimal indices and define the associated W^mm(opt) to be a globally-optimal direction.


Lemma 15 (Mapping the regularization path of (K, Q) to W) Let K, Q ∈ R^{d×m} and consider the regularization path solutions of (SERM-W) and (SERM-KQ):

    W̄_R ∈ arg min_{W∈R_m : ∥W∥_⋆ ≤ R} L(W),    (91)
    (K̄_R, Q̄_R) ∈ arg min_{∥K∥²_F + ∥Q∥²_F ≤ 2R} L(K, Q).    (92)
For all R ≥ 0, there is a one-to-one map between the set of solutions W̄R of (91) and K̄R Q̄⊤R of (92).
Proof. To prove the mapping, first fix a solution W̄_R with rank m, set L_F = L(W̄_R), and show the existence of K, Q with KQ⊤ = W̄_R that is feasible for (92) and satisfies L(K, Q) ≤ L_F. Use the singular value decomposition W̄_R = UΣV⊤ with Σ ∈ R^{m×m} being the diagonal matrix of singular values. Set K = U√Σ and Q = V√Σ. Observe that KQ⊤ = W̄_R and

    ∥K∥²_F = ∥Q∥²_F = ∑_{i=1}^m Σ_ii = ∥W̄_R∥_⋆ ≤ R.

Thus, (K, Q) is feasible and achieves L(K, Q) = L_F. Conversely, given (K̄_R, Q̄_R) with L⋆ = L(K̄_R, Q̄_R), the matrix W = K̄_R Q̄_R⊤ obeys L(W) = L⋆ and, using the standard nuclear norm inequality, we have

    ∥W∥_⋆ = ∥K̄_R Q̄_R⊤∥_⋆ ≤ (1/2)(∥K̄_R∥²_F + ∥Q̄_R∥²_F) ≤ R.

This shows W is feasible for (91). Combining the two findings above, we find that the optimal costs are equal (L⋆ = L_F) and for any (K̄_R, Q̄_R) solution there exists a W̄_R solution and vice versa. ■
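The factorization used in the proof can be checked numerically; the sketch below (illustrative sizes, numpy SVD) reproduces K = U√Σ, Q = V√Σ and the norm identity ∥K∥²_F = ∥Q∥²_F = ∥W∥_⋆.

```python
# Numerical check of the (K, Q) <-> W mapping in Lemma 15.
import numpy as np

rng = np.random.default_rng(6)
d, m = 8, 3
W = rng.standard_normal((d, m)) @ rng.standard_normal((m, d))     # a rank-m matrix

U, sig, Vt = np.linalg.svd(W)
K = U[:, :m] * np.sqrt(sig[:m])                                   # K = U sqrt(Sigma)
Q = Vt[:m].T * np.sqrt(sig[:m])                                   # Q = V sqrt(Sigma)

print(np.allclose(K @ Q.T, W))                                    # K Q^T recovers W
print(np.linalg.norm(K)**2, np.linalg.norm(Q)**2, np.linalg.norm(W, "nuc"))
# ||K||_F^2 = ||Q||_F^2 = ||W||_*, so ||K||_F^2 + ||Q||_F^2 <= 2R corresponds to ||W||_* <= R.
```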
To proceed with our analysis, let us define the set of optimal solutions of (SAtt-SVM) and denote it by W^mm_α := W^mm_⋄(α). Note that, if the ⋄-norm is not strongly convex, W^mm_α may not be a singleton.

E.1 Local regularization path and proof of Theorem 7
We first recall the local regularization path, which solves the ⋄-norm-constrained problem over a conic slice C, namely W̄(R) = min_{∥W∥_⋄ ≤ R, W∈C} L(W). We will show that a properly-chosen local RP directionally converges to locally-optimal directions. Setting the cone to be the rank-m manifold R_m (and, more specifically, the set of all matrices R^{d×d}), this will also establish the convergence of the global RP to the globally-optimal direction.
Our cone definition cone_ϵ(α) is induced by a token selection α = (α_ik)_{ik=(1,1)}^{(n,K)} and has a simple interpretation: it prioritizes tokens with lower scores than α over tokens with higher scores than α. This way, the lower-score tokens create a barrier for α and prevent the optimization from moving towards higher-score tokens.
Definition 5 (Low & High Score Tokens and Separating Cone) Given α ∈ [T], an input sequence X with label Y, h(·): R^d → R, and scores γ_t = Y · h(x_t) for all t ∈ [T], define the low- and high-score tokens as

    low^α(X) = { t ∈ [T] : γ_t < γ_α },   high^α(X) = { t ∈ [T] − {α} : γ_t ≥ γ_α }.

For the input pair indexed by (i, k) and the index α_ik, we use the shorthand notations low^α_ik, high^α_ik. Finally, define cone_ϵ(α) as

    cone_ϵ(α) = { W ∈ R_m : min_{i∈[n]} max_{t∈low^α_ik} min_{τ∈high^α_ik} ⟨F_ikt − F_ikτ, W⟩ ≥ ϵ∥W∥_F }.    (93)

Lemma 16 Consider the cone definition (93) and suppose an SVM solution W_α^mm exists. If the indices α are locally-optimal, then W_α^mm ∈ cone_ϵ(α) for all sufficiently small ϵ > 0. Otherwise, W_α^mm ∉ cone_ϵ(α) for any ϵ > 0. Additionally, suppose the optimal indices opt_ik ∈ arg max_{t∈[T]} γ_ikt are unique. Then, cone_ϵ(opt) = R_m.

Proof. Suppose α is locally optimal. Observe that, thanks to local optimality, W_α^mm obeys

    min_{t∈T_ik} ⟨F_ikt, W_α^mm⟩ > max_{τ∉T_ik∪{α_ik}} ⟨F_ikτ, W_α^mm⟩,

for all i ∈ [n]. Next, observe that T_ik ⊆ low^α_ik and high^α_ik ⊆ T̄_ik = [T] − T_ik − {α_ik}. Thus, inequality (93) holds for small enough ϵ > 0.
Conversely, suppose α is not locally-optimal. Fix a support index t ∈ T_ik with t ∈ high^α_ik. Since t ∈ T_ik, observe that

    ⟨F_ikt, W_α^mm⟩ ≥ max_{τ≠α_ik} ⟨F_ikτ, W_α^mm⟩.

In other words, for this i ∈ [n], we found

    max_{τ∈low^α_ik} ⟨F_ikτ − F_ikt, W_α^mm⟩ ≤ 0,

violating the definition (93) for any ϵ > 0. To show the final claim, observe that, setting α := opt, we have high^α_ik = ∅ for all i ∈ [n], since the opt_ik are unique optimal indices. Thus, no constraint is enforced in the cone definition (93), making it equal to the rank-m manifold R_m. ■
Our main assumption regarding prediction head is a monotonicity condition which is a strict generalization of
linearity: We ask for h to preserve the order of token scores under convex combinations.
Assumption E (h preserves the top score) The functions (h_k)_{k=1}^K are L_h-Lipschitz in Euclidean distance. Given α = (α_ik)_{ik=(1,1)}^{(n,K)}, there exists a scalar c := c_α > 0 such that for all i ∈ [n], k ∈ [K] the following holds: Consider any convex combination

    x(s) = ∑_{t∈low^α_ik ∪ {α_ik}} s_t · x_it,   where   ∑_{t∈low^α_ik ∪ {α_ik}} s_t = 1,  s_t ≥ 0.

We have that Y_ik · h_k(x(s)) ≤ γ_{α_ik} − c(1 − s_{α_ik}), where γ_{α_ik} = Y_ik · h_k(x_{α_ik}) is the score of α_ik.
This condition states that convex combinations of tokens with scores lower than that of α_ik cannot achieve a score higher than that of α_ik. Here, the term 1 − s_{α_ik} denotes the total share of the non-optimal tokens. We require this condition to hold over the training dataset rather than over the full domain R^d. Crucially, it is a strict generalization of the linearity assumption: any linear h_k satisfies Assumption E by setting c_α > 0 to be the difference between the score of α_ik and the largest score within low^α_ik, i.e.,

    c_α := min_{i∈[n], k∈[K]} { γ_{α_ik} − max_{t∈low^α_ik} γ_ikt } > 0.    (94)

This can be seen by writing Y_ik · h_k(x(s)) = ∑_{t∈low^α_ik ∪ {α_ik}} s_t · γ_ikt = γ_{α_ik} + ∑_{t∈low^α_ik} s_t · (γ_ikt − γ_{α_ik}) ≤ γ_{α_ik} − c_α(1 − s_{α_ik}). To provide a nonlinear example, consider the setting where all labels are Y_ik = 1 and h is an arbitrary convex function. Thanks to convexity, we can write h(x(s)) ≤ ∑_{t=1}^T s_t · h(x_t) = ∑_{t∈low^α_ik ∪ {α_ik}} s_t · γ_ikt. Thus, we can use the same choice of c_α as in (94).
P P
We remark that, in Section 6, we derive formulae describing general inductive bias of attention without enforcing
any assumption on h. These formulae allow for arbitrary output tokens generated by the transformer model trained by
gradient descent. This includes the setting where gradient descent selects and composes multiple tokens from each
sequence rather than a single token αik .
The following result is our main theorem regarding the convergence of regularization path to locally-optimal
directions when restricted over the cone coneϵ (α).
Theorem 9 (Convergence of Local Regularization Path) Suppose (SAtt-SVM) is feasible and α = (α_ik)_{ik=(1,1)}^{(n,K)} are locally-optimal token indices. Suppose Assumptions A & E hold. Recall cone_ϵ(α) of (93) and consider the norm-constrained cone

    C^ϵ_{⋄,R_0} := cone_ϵ(α) ∩ { W : ∥W∥_⋄ ≥ R_0 }.

Define the conic regularization path W̄(R) = min_{W∈C^ϵ_{⋄,R_0}, ∥W∥_⋄ ≤ R} L(W). Let W^mm_α be its set of minima and Ξ_⋄ > 0 be the associated margin, i.e., Ξ_⋄ = 1/∥W^mm_α∥_⋄. For any sufficiently small ϵ > 0 and sufficiently large R_0 = O(1/ϵ) > 0,

    lim_{R→∞} dist( W̄(R)/(RΞ_⋄), W^mm_α ) = 0.

Additionally, suppose the optimal indices opt = (opt_ik)_{ik=(1,1)}^{(n,K)} are unique and set α ← opt. Then, the same RP convergence guarantee holds with C^ϵ_{⋄,R_0} = R_m.

Proof. We will prove that $\bar{W}(R)$ achieves the optimal direction and also that $\|\bar{W}(R)\|_\diamond \to \infty$. Define the absolute constant
$$
c_\diamond = \min_{\|W\|_\diamond = 1} \|W\|_F.
$$
This guarantees that for any W we have $\|W\|_F \ge c_\diamond \|W\|_\diamond$. Also denote $\varepsilon_\diamond = c_\diamond \epsilon$. Let us first determine the ϵ parameter: Fix $W^{mm}_\alpha \in \mathcal{W}^{mm}_\alpha$. For general α, we can choose any ϵ > 0 that is sufficiently small to guarantee $W^{mm}_\alpha \in \text{cone}_\epsilon(\alpha)$ based on Lemma 16. For α = opt, our analysis will entirely avoid using ϵ; specifically, observe that $\text{cone}_\epsilon(\alpha) = \mathcal{R}_m$ based on Lemma 16.
Step 1: Let us first prove that $\bar{W}(R)$ achieves the optimal risk as R → ∞ – rather than the problem having finite optima. Define the norm-normalized $\bar{W}^{mm} = \Xi_\diamond W^{mm}_\alpha$. Note that $W^{mm}_\alpha$ separates the tokens α from the rest of the tokens for each $(i,k) \in [n]\times[K]$. Thus, we have that
$$
\lim_{R\to\infty} \mathcal{L}(\bar{W}(R)) \le \lim_{R\to\infty} \mathcal{L}(R\cdot \bar{W}^{mm}) := \mathcal{L}_\star = \frac{1}{n}\sum_{i=1}^n \sum_{k=1}^K \ell(\gamma_{\alpha_{ik}}). \tag{95}
$$
On the other hand, for any choice of $W \in \text{cone}_\epsilon(\alpha)$, set $x^W_{ik} = \sum_{t=1}^T S(X_i W z_{ik})_t\, x_{it}$. Set the softmax probabilities $\mathbf{s}^{(ik)} = S(X_i W z_{ik})$. Recalling the $\text{low}^\alpha_{ik}$, $\text{high}^\alpha_{ik}$ definitions, we can decompose the attention features as
$$
x^W_{ik} = s^{(ik)}_{\alpha_{ik}}\, x_{i\alpha_{ik}} + \sum_{t \in \text{low}^\alpha_{ik}} s^{(ik)}_t\, x_{it} + \sum_{\tau \in \text{high}^\alpha_{ik}} s^{(ik)}_\tau\, x_{i\tau}. \tag{96}
$$

When α = opt, note that we simply have $\text{high}^\alpha_{ik} = \emptyset$. This will be important for setting $R_0 = 0$ and $\mathcal{C}^\diamond_{\epsilon,R_0} = \mathcal{R}_m$ in the proof for the opt indices.
Set $\gamma^{\text{gap}}_{ikt} = \gamma_{ikt} - \gamma_{\alpha_{ik}} = Y_{ik}\cdot (h_k(x_{it}) - h_k(x_{\alpha_{ik}}))$. Building on the $L_h$-Lipschitzness of the prediction head $h_k(\cdot)$, we define
$$
B := \max_{i\in[n],k\in[K]}\, \max_{t,\tau\in[T]} L_h \cdot \|x_{it} - x_{i\tau}\| \;\ge\; |\gamma^{\text{gap}}_{ikt}|. \tag{97}
$$
Define $P_{ik} := \sum_{t\in \text{low}^\alpha_{ik}} s^{(ik)}_t$, $Q_{ik} := \sum_{t\in \text{high}^\alpha_{ik}} s^{(ik)}_t$, and $\gamma^W_{ik} = Y_{ik}\cdot h_k(x^W_{ik})$. Also set the temporary variables $x' = (s^{(ik)}_{\alpha_{ik}} + Q_{ik})\, x_{i\alpha_{ik}} + \sum_{t\in \text{low}^\alpha_{ik}} s^{(ik)}_t\, x_{it}$ and $\gamma' = Y_{ik}\cdot h_k(x')$. Using Assumption E on $x'$ and noticing $P_{ik} = 1 - s^{(ik)}_{\alpha_{ik}} - Q_{ik}$, observe that
$$
|\gamma^W_{ik} - \gamma'| \le B\, Q_{ik} \quad \text{and} \quad \gamma' \le \gamma_{\alpha_{ik}} - c_\alpha P_{ik}.
$$

Recall from (94) that, when the $h_k$ are linear functions, $c_\alpha$ can be chosen as
$$
c_\alpha := \min_{i\in[n],k\in[K]}\; \min_{t\in \text{low}^\alpha_{ik}} -\gamma^{\text{gap}}_{ikt} > 0.
$$
To summarize, applying Assumption E, we obtain the following score inequalities
$$
\gamma^W_{ik} \le \gamma_{\alpha_{ik}} - c_\alpha P_{ik} + B Q_{ik}, \tag{98}
$$
$$
|\gamma^W_{ik} - \gamma_{\alpha_{ik}}| \le L_h \|x^W_{ik} - x_{i\alpha_{ik}}\| \le L_h \sum_{t\ne \alpha_{ik}} s^{(ik)}_t \|x_{it} - x_{i\alpha_{ik}}\| \le B\,(1 - s^{(ik)}_{\alpha_{ik}}). \tag{99}
$$

We will use the $\gamma^W_{ik} - \gamma_{\alpha_{ik}}$ term in (98) to evaluate W against the reference loss (95). Let $a^{(ik)} = X_i W z_{ik}$. Now, since $W \in \text{cone}_\epsilon(\alpha)$, there exists $t \in \text{low}^\alpha_{ik}$ obeying $a^{(ik)}_t - \max_{\tau\in \text{high}^\alpha_{ik}} a^{(ik)}_\tau \ge \epsilon\|W\|_F \ge \varepsilon_\diamond \|W\|_\diamond$. Denote $D_{ik} := (\sum_{t\in[T]} e^{a^{(ik)}_t})^{-1}$, the inverse of the softmax denominator, i.e., of the sum of exponentials. We find that
$$
Q_{ik} = \sum_{\tau\in \text{high}^\alpha_{ik}} s^{(ik)}_\tau = D_{ik} \sum_{\tau\in \text{high}^\alpha_{ik}} e^{a^{(ik)}_\tau} \le D_{ik}\, T\, e^{a^{(ik)}_t - \epsilon\|W\|_F} \le T\, e^{-\varepsilon_\diamond \|W\|_\diamond}\, P_{ik}, \tag{100}
$$
where the last step uses $D_{ik} e^{a^{(ik)}_t} = s^{(ik)}_t \le P_{ik}$ and $\epsilon\|W\|_F \ge \varepsilon_\diamond\|W\|_\diamond$.
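As a quick numerical sanity check of the tail bound (100) (an aside, not part of the proof; the logits, margin, and index sets below are made up), one can verify that whenever some low-index logit exceeds every high-index logit by a margin, the softmax mass on the high set is at most $T e^{-\text{margin}}$ times the mass on the low set.

```python
import numpy as np

rng = np.random.default_rng(1)
T, margin = 8, 3.0                        # sequence length and logit margin (illustrative)
low, high = [0, 1, 2], [3, 4, 5]          # hypothetical low / high index sets

def softmax(a):
    e = np.exp(a - a.max())
    return e / e.sum()

for _ in range(1000):
    a = rng.standard_normal(T)
    a[low[0]] = a[high].max() + margin    # cone condition: a low logit beats every high logit
    s = softmax(a)
    Q = s[high].sum()                     # softmax mass on the high tokens
    P = s[low].sum()                      # softmax mass on the low tokens
    assert Q <= T * np.exp(-margin) * P + 1e-12
print("Tail bound of Eq. (100) holds on all sampled logits.")
```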

Consequently, the score difference obeys
$$
\gamma^W_{ik} - \gamma_{\alpha_{ik}} \le B Q_{ik} - c_\alpha P_{ik} \le \big(B T e^{-\varepsilon_\diamond\|W\|_\diamond} - c_\alpha\big) P_{ik}.
$$
Above, the right-hand side is strictly negative as soon as $\|W\|_\diamond \ge R_0 := \frac{1}{\varepsilon_\diamond}\log\frac{BT}{c_\alpha}$. Note that this condition applies to all $(i,k)\in[n]\times[K]$ pairs uniformly for the same $R_0$. Consequently, for any $\|W\|_\diamond \ge R_0$, for all $(i,k)$ and $W \in \text{cone}_\epsilon(\alpha)$, we have that $\gamma^W_{ik} < \gamma_{\alpha_{ik}}$. Additionally, when α = opt, note that $Q_{ik} = 0$ since $\text{high}^\alpha_{ik} = \emptyset$; thus $R_0 = 0$ suffices to ensure $\gamma^W_{ik} < \gamma_{\alpha_{ik}}$. Using the strictly-decreasing nature of ℓ, we conclude with the fact that for all (finite) $W \in \text{cone}_\epsilon(\alpha)$,
$$
\mathcal{L}(W) = \frac{1}{n}\sum_{i=1}^n\sum_{k=1}^K \ell(\gamma^W_{ik}) > \mathcal{L}_\star = \frac{1}{n}\sum_{i=1}^n\sum_{k=1}^K \ell(\gamma_{\alpha_{ik}}),
$$

which implies ∥W̄(R)∥⋄ → ∞.


Step 2: To proceed, we show that $\bar{W}(R)$ converges in direction to $\mathcal{W}^{mm}_\alpha$. Suppose this is not the case, i.e., convergence fails. We will obtain a contradiction by showing that $\bar{W}^{mm}_R = R\cdot \bar{W}^{mm}$ achieves a strictly superior loss compared to $\bar{W}(R)$. Also define the normalized parameters $\bar{W}_0(R) = \frac{\bar{W}(R)}{R\,\Xi_\diamond}$ and $W' = \frac{\bar{W}(R)}{\|\bar{W}(R)\|_\diamond \Xi_\diamond}$. Note that $\bar{W}_0(R)$ is obtained by scaling down $W'$, since $\|\bar{W}(R)\|_\diamond \le R$ and $W'$ obeys $\|W'\|_\diamond = \|W^{mm}_\alpha\|_\diamond$.
Since $\bar{W}_0(R)$ fails to converge to $\mathcal{W}^{mm}_\alpha$, for some δ > 0 there exists arbitrarily large R > 0 such that $\text{dist}(\bar{W}_0(R), \mathcal{W}^{mm}_\alpha) \ge \delta$. This translates to suboptimality in terms of the margin constraints as follows: First, the distance with respect to the ⋄-norm obeys $\text{dist}_\diamond(\bar{W}_0(R), \mathcal{W}^{mm}_\alpha) \ge \delta$ for an updated $\delta \leftarrow c_\diamond\delta$. Secondly, using the triangle inequality, this implies that either $\|\bar{W}_0(R)\|_\diamond \le \|W^{mm}_\alpha\|_\diamond - \delta/2$ or $\text{dist}_\diamond(W', \mathcal{W}^{mm}_\alpha) \ge \delta/2$.
In either scenario, $\bar{W}_0(R)$ strictly violates one of the margin constraints of (SAtt-SVM): If $\|\bar{W}_0(R)\|_\diamond \le \|W^{mm}_\alpha\|_\diamond - \delta/2$, then, since the optimal SVM objective is $\|W^{mm}_\alpha\|_\diamond$, there exists a constraint $(i,k)$ for which $\langle F^\alpha_{ik} - F_{ikt}, \bar{W}_0(R)\rangle \le 1 - \frac{\delta}{2\|W^{mm}_\alpha\|_\diamond}$. If $\text{dist}_\diamond(W', \mathcal{W}^{mm}_\alpha) \ge \delta/2$, then $W'$ has the same SVM objective but is strictly bounded away from the solution set. Thus, for some $\epsilon := \epsilon(\delta) > 0$, $W'$ and its scaled-down version $\bar{W}_0(R)$ strictly violate an SVM constraint, achieving margin $\le 1 - \epsilon$. Without loss of generality, suppose $\bar{W}_0(R)$ violates the first constraint. Thus, for a properly updated δ > 0 (a function of the initial δ > 0), for $(i,k) = (1,1)$ and some support index $\tau \in \mathcal{T}_{11}$,
$$
\langle F^\alpha_{11} - F_{11\tau}, \bar{W}_0(R)\rangle \le 1 - \delta. \tag{101}
$$

Now, we will argue that this leads to a contradiction by proving $\mathcal{L}(\bar{W}^{mm}_R) < \mathcal{L}(\bar{W}(R))$ for sufficiently large R. To obtain the result, we establish a refined softmax probability control as in Step 1 by studying the distance to $\mathcal{L}_\star$. Following (98), denote the score function at $\bar{W}(R)$ via the shorthand $\gamma^R_{ik} = \gamma^{\bar{W}(R)}_{ik}$. Similarly, let $\mathbf{s}^R_{ik} = S(a^R_{ik})$ with $a^R_{ik} = X_i \bar{W}(R) z_{ik}$. Set the corresponding notation for the reference parameter $\bar{W}^{mm}_R$ as $\gamma^\star_{ik}, \mathbf{s}^\star_{ik}, a^\star_{ik}$.

Critically, recall that the above inequalities (100) apply to both $W \in \{\bar{W}(R), \bar{W}^{mm}_R\} \subset \text{cone}_\epsilon(\alpha)$: for an index $(i,k)$ and support index $t \in \mathcal{T}_{ik}$,
$$
Q_{ik} = \sum_{\tau\in \text{high}^\alpha_{ik}} s_{ik\tau} = D_{ik} \sum_{\tau\in \text{high}^\alpha_{ik}} e^{a_{ik\tau}} \le D_{ik}\, T\, e^{a_{ikt} - \varepsilon_\diamond\|W\|_\diamond} \le T e^{-\varepsilon_\diamond\|W\|_\diamond} P_{ik} \le T e^{-\varepsilon_\diamond\|W\|_\diamond} (1 - s_{ik\alpha_{ik}}), \tag{102}
$$
where $P_{ik} = \sum_{\tau\in \text{low}^\alpha_{ik}} s_{ik\tau}$ and $P_{ik} + Q_{ik} = 1 - s_{ik\alpha_{ik}}$.
Note that, setting $R_0 \ge \mathcal{O}(1/\varepsilon_\diamond) = \mathcal{O}(1/\epsilon)$, we guarantee that, for any $(i,k) \in [n]\times[K]$,
$$
P_{ik} \ge Q_{ik} \implies P_{ik} \ge 0.5\,(1 - s_{ik\alpha_{ik}}). \tag{103}
$$
Additionally, when α = opt, note that $Q_{ik} = 0$ since $\text{high}^\alpha_{ik} = \emptyset$. Thus, $R_0 = 0$ suffices to ensure (103).
To proceed, recall that $R \ge \|\bar{W}(R)\|_\diamond \ge R_0$ by definition, since $\bar{W}(R) \in \mathcal{C}^\diamond_{\epsilon,R_0}$, and recall $\Xi_\diamond := 1/\|W^{mm}_\alpha\|_\diamond$. Equipped with these, we note the following softmax inequalities on the selected tokens $\alpha_{ik}$:
$$
s^\star_{ik\alpha_{ik}} \ge \frac{1}{1 + T e^{-R\Xi_\diamond}} \ge 1 - T e^{-R\Xi_\diamond} \quad \text{for all } (i,k)\in[n]\times[K], \tag{104}
$$
$$
s^R_{ik\alpha_{ik}} \le \frac{1}{1 + e^{-(1-\delta)\|\bar{W}(R)\|_\diamond \Xi_\diamond}} \le \frac{1}{1 + e^{-(1-\delta)R\Xi_\diamond}} \quad \text{for } (i,k)=(1,1).
$$
The former inequality is thanks to $W^{mm}_\alpha$ achieving margins $\ge 1$ on all tokens $[T] - \alpha_{ik}$, and the latter arises from the δ-margin violation of $\bar{W}(R)$ at $(i,k)=(1,1)$, i.e., Eq. (101). Since ℓ is strictly decreasing with a Lipschitz gradient, and the scores are upper/lower bounded by an absolute constant (as the tokens are bounded, $(h_k)_{k=1}^K$ are Lipschitz, and both are fixed), we know that $c_{\text{up}} \ge -\ell'(\gamma^W_{ik}) \ge c_{\text{dn}}$ for some constants $c_{\text{up}} > c_{\text{dn}} > 0$. Thus, following Eq. (97) and the score decomposition (98), and using (102), (103), (104), we can write


$$
\begin{aligned}
\mathcal{L}(\bar{W}(R)) - \mathcal{L}_\star &\ge \frac{1}{n}\big[\ell(\gamma^{\bar{W}(R)}_{11}) - \ell(\gamma_{\alpha_{11}})\big] \ge \frac{c_{\text{dn}}}{n}\big(\gamma_{\alpha_{11}} - \gamma^{\bar{W}(R)}_{11}\big) \\
&\ge \frac{c_{\text{dn}}}{n}\big(c_\alpha P^{\bar{W}(R)}_{11} - B\, Q^{\bar{W}(R)}_{11}\big) \qquad (105) \\
&\ge \frac{c_{\text{dn}}}{n}\big(1 - s^R_{11\alpha_{11}}\big)\big(0.5\,c_\alpha - B T e^{-\varepsilon_\diamond\|\bar{W}(R)\|_\diamond}\big) \\
&\ge \frac{c_{\text{dn}}}{n}\,\frac{1}{1 + e^{(1-\delta)R\Xi_\diamond}}\big(0.5\,c_\alpha - B T e^{-\varepsilon_\diamond R_0}\big).
\end{aligned}
$$
Above, recalling the choice $R_0 \ge \mathcal{O}(1/\varepsilon_\diamond) = \mathcal{O}(1/\epsilon)$, $R \ge R_0$ implies $B T e^{-\varepsilon_\diamond R_0} \le c_\alpha/4$, and we obtain
$$
\mathcal{L}(\bar{W}(R)) - \mathcal{L}_\star \ge \frac{c_{\text{dn}}\cdot c_\alpha}{4n}\,\frac{1}{1 + e^{(1-\delta)R\Xi_\diamond}}. \tag{106}
$$
Additionally, when α = opt, since $Q^{\bar{W}(R)}_{11} = 0$ in (105), the bound above holds with $R_0 = 0$ by directly using (105).
Conversely, we upper bound the difference between $\mathcal{L}(\bar{W}^{mm}_R)$ and $\mathcal{L}_\star$ as follows. Define the worst-case loss difference for $\bar{W}^{mm}_R$ via $(i',k') = \arg\max_{i\in[n],k\in[K]} [\ell(\gamma^\star_{ik}) - \ell(\gamma_{\alpha_{ik}})]$. Using (99) & (104), we write
$$
\begin{aligned}
\mathcal{L}(\bar{W}^{mm}_R) - \mathcal{L}_\star &\le \max_{i\in[n],k\in[K]} [\ell(\gamma^\star_{ik}) - \ell(\gamma_{\alpha_{ik}})] \le c_{\text{up}}\cdot(\gamma_{\alpha_{i'k'}} - \gamma^\star_{i'k'}) \\
&\le c_{\text{up}}\cdot(1 - s^\star_{i'k'\alpha_{i'k'}})\,B \qquad (107) \\
&\le c_{\text{up}}\cdot T e^{-R\Xi_\diamond}\,B.
\end{aligned}
$$

Combining the last inequality and (106), we conclude that $\mathcal{L}(\bar{W}^{mm}_R) < \mathcal{L}(\bar{W}(R))$ whenever
$$
c_{\text{up}} T e^{-R\Xi_\diamond} B < \frac{c_{\text{dn}}\cdot c_\alpha}{4n}\,\frac{1}{1 + e^{(1-\delta)R\Xi_\diamond}} \iff \frac{e^{R\Xi_\diamond}}{1 + e^{(1-\delta)R\Xi_\diamond}} > \frac{4 c_{\text{up}} T n B}{c_{\text{dn}} c_\alpha}.
$$
The latter inequality holds for all sufficiently large R: since $\frac{e^{R\Xi_\diamond}}{1 + e^{(1-\delta)R\Xi_\diamond}} \ge \frac{e^{\delta R\Xi_\diamond}}{2}$ once $e^{(1-\delta)R\Xi_\diamond} \ge 1$, it suffices that R obeys $R > \frac{1}{\delta\Xi_\diamond}\log\big(\frac{8 c_{\text{up}} T n B}{c_{\text{dn}} c_\alpha}\big)$. This completes the proof of the theorem via contradiction, as we obtained $\mathcal{L}(\bar{W}(R)) > \mathcal{L}(\bar{W}^{mm}_R)$.

E.2 Global regularization path
The following result is a direct corollary of Theorem 9. Namely, we simply restate the final claim of that theorem, which applies to optimal tokens.
Corollary 1 (Global Convergence of Regularization Path) Suppose Assumptions A&E hold and the optimal indices $\text{opt}_{ik} = \arg\max_{t\in[T]}\gamma_{ikt}$ are unique. Consider the global regularization path $\bar{W}_{\diamond,R} = \min_{W\in\mathcal{R}_m,\,\|W\|_\diamond\le R}\mathcal{L}(W)$. Let $\mathcal{W}^{mm}_\diamond$ be the non-empty solution set of (SAtt-SVM) with $\alpha \leftarrow \text{opt}$, normalized to have unit ⋄-norm. Then
$$
\lim_{R\to\infty}\text{dist}\Big(\frac{\bar{W}_{\diamond,R}}{R}, \mathcal{W}^{mm}_\diamond\Big) = 0.
$$
The next corollary directly targets the application to (SERM-W) and (SERM-KQ). This corollary is also a strict generalization of Theorem 2. Specifically, we immediately recover Theorem 2 by specializing it to the single-output setting $K \leftarrow 1$ and the full-dimensional parameterization $m \leftarrow d$.
Corollary 2 Suppose Assumptions A&E hold and the optimal indices $\text{opt}_{ik} = \arg\max_{t\in[T]}\gamma_{ikt}$ are unique. Consider the regularization paths associated to (SERM-W) and (SERM-KQ):
$$
\bar{W}_R = \arg\min_{W\in\mathcal{R}_m,\,\|W\|_F\le R}\mathcal{L}(W) \quad \text{and} \quad (\bar{K}_R, \bar{Q}_R) = \arg\min_{\|K\|_F^2 + \|Q\|_F^2 \le 2R}\mathcal{L}(K,Q). \tag{108}
$$
Suppose (SAtt-SVM) is feasible for $\alpha \leftarrow \text{opt}$. Let $W^{mm}$ be the unique solution of (SAtt-SVM) with the Frobenius norm, and let $\mathcal{W}^{mm}_\star$ be the solution set of (SAtt-SVM) with the nuclear norm and cost $\|W^{mm}_\star\|_\star$. We have that
$$
\lim_{R\to\infty}\frac{\bar{W}_R}{R} = \frac{W^{mm}}{\|W^{mm}\|_F}, \qquad \lim_{R\to\infty}\text{dist}\Big(\frac{\bar{Q}_R\bar{K}_R^\top}{R}, \frac{\mathcal{W}^{mm}_\star}{\|W^{mm}_\star\|_\star}\Big) = 0.
$$

Proof. We directly apply Corollary 1 with ⋄ = F and ⋄ = ⋆, respectively. To obtain the result on $\bar{W}_R$, we note that $W^{mm}$ is unique because the squared Frobenius norm is strongly convex. To obtain the result on $(\bar{Q}_R, \bar{K}_R)$, we use Lemma 15 and observe that
$$
\bar{W}_{\star,R} := \bar{Q}_R\bar{K}_R^\top \in \arg\min_{W\in\mathcal{R}_m,\,\|W\|_\star\le R}\mathcal{L}(W).
$$
We then apply Corollary 1 with ⋄ = ⋆ to conclude with the convergence of the path $\bar{W}_{\star,R}$.

E.2.1 Proof of Theorem 2


Corollary 2 already proves Theorem 2 through the more general Theorem 9. Below, we provide a self-contained proof of Theorem 2 for clarity.
Proof. Throughout, ⋄ denotes either the Frobenius norm or the nuclear norm. We will prove that $\bar{W}(R)$ asymptotically aligns with the set of globally-optimal directions and also that $\|\bar{W}(R)\|_\diamond \to \infty$. Let $\mathcal{R}_m \subseteq \mathbb{R}^{d\times d}$ denote the manifold of matrices of rank at most m.
Step 1: Let us first prove that $\bar{W}(R)$ achieves the optimal risk as R → ∞ – rather than the problem having finite optima. Define $\Xi_\diamond = 1/\|W^{mm}\|_\diamond$ and the norm-normalized $\bar{W}^{mm} = \Xi_\diamond W^{mm}$. Note that $W^{mm}$ separates the tokens $\text{opt}_i$ from the rest of the tokens for each i ∈ [n]. Thus, we have that
$$
\lim_{R\to\infty}\mathcal{L}(\bar{W}(R)) \le \lim_{R\to\infty}\mathcal{L}(R\cdot\bar{W}^{mm}) := \mathcal{L}_\star = \frac{1}{n}\sum_{i=1}^n \ell(\gamma^{\text{opt}}_i). \tag{109}
$$

On the other hand, for any $W \in \mathcal{R}_m$, define the softmax probabilities $\mathbf{s}^{(i)} = S(X_i W z_i)$ and the attention features $x^W_i = \sum_{t=1}^T s^{(i)}_t x_{it}$. Decompose $x^W_i$ as $x^W_i = s^{(i)}_{\text{opt}_i} x_{i\,\text{opt}_i} + \sum_{t\ne\text{opt}_i} s^{(i)}_t x_{it}$. Set $\gamma^{\text{gap}}_{it} = \gamma^{\text{opt}}_i - \gamma_{it} = Y_i\cdot v^\top(x_{i\,\text{opt}_i} - x_{it}) > 0$, and define
$$
B := \max_{i\in[n]}\,\max_{t,\tau\in[T]} \|v\|\cdot\|x_{it} - x_{i\tau}\| \ge \gamma^{\text{gap}}_{it}. \tag{110}
$$
Define $c_{\text{opt}} = \min_{i\in[n],\,t\ne\text{opt}_i}\gamma^{\text{gap}}_{it} > 0$ and $\gamma^W_i = Y_i\cdot v^\top x^W_i$. We obtain the following score inequalities
$$
\gamma^W_i \le \gamma^{\text{opt}}_i - c_{\text{opt}}\,(1 - s^{(i)}_{\text{opt}_i}) < \gamma^{\text{opt}}_i, \tag{111}
$$
$$
|\gamma^W_i - \gamma^{\text{opt}}_i| \le \|v\|\cdot\|x^W_i - x_{i\,\text{opt}_i}\| \le \|v\|\sum_{t\ne\text{opt}_i} s^{(i)}_t\|x_{it} - x_{i\,\text{opt}_i}\| \le B\,(1 - s^{(i)}_{\text{opt}_i}).
$$

We will use the $\gamma^W_i - \gamma^{\text{opt}}_i$ term in (111) to evaluate W against the reference loss $\mathcal{L}_\star$ of (109). Using the strictly-decreasing nature of ℓ, we conclude with the fact that for all (finite) $W \in \mathcal{R}_m$,
$$
\mathcal{L}(W) = \frac{1}{n}\sum_{i=1}^n \ell(\gamma^W_i) > \mathcal{L}_\star = \frac{1}{n}\sum_{i=1}^n \ell(\gamma^{\text{opt}}_i),
$$
which implies $\|\bar{W}(R)\|_\diamond \to \infty$ together with (109).


Step 2: To proceed, we show that $\bar{W}(R)$ converges in direction to $\mathcal{W}^{mm}$, which denotes the set of SVM minima. Suppose this is not the case and convergence fails. We will obtain a contradiction by showing that $\bar{W}^{mm}_R = R\cdot\bar{W}^{mm}$ achieves a strictly superior loss compared to $\bar{W}(R)$. Let us introduce the normalized parameters $\bar{W}_0(R) = \frac{\bar{W}(R)}{R\,\Xi_\diamond}$ and $W' = \frac{\bar{W}(R)}{\|\bar{W}(R)\|_\diamond\Xi_\diamond}$. Note that $\bar{W}_0(R)$ is obtained by scaling down $W'$, since $\|\bar{W}(R)\|_\diamond \le R$ and $W'$ obeys $\|W'\|_\diamond = \|W^{mm}\|_\diamond$. Since $\bar{W}_0(R)$ fails to converge to $\mathcal{W}^{mm}$, for some δ > 0 there exists arbitrarily large R > 0 such that $\text{dist}(\bar{W}_0(R), \mathcal{W}^{mm}) \ge \delta$.
This translates to suboptimality in terms of the margin constraints as follows: First, since the nuclear norm dominates the Frobenius norm, the distance with respect to the ⋄-norm obeys $\text{dist}_\diamond(\bar{W}_0(R), \mathcal{W}^{mm}) \ge \delta$. Secondly, using the triangle inequality, this implies that either $\|\bar{W}_0(R)\|_\diamond \le \|W^{mm}\|_\diamond - \delta/2$ or $\text{dist}_\diamond(W', \mathcal{W}^{mm}) \ge \delta/2$.
In either scenario, $\bar{W}_0(R)$ strictly violates one of the margin constraints of (SAtt-SVM): If $\|\bar{W}_0(R)\|_\diamond \le \|W^{mm}\|_\diamond - \delta/2$, then, since the optimal SVM objective is $\|W^{mm}\|_\diamond$, there exists a constraint $(i, t \ne \text{opt}_i)$ for which $\langle (x_{i\,\text{opt}_i} - x_{it})\, z_i^\top, \bar{W}_0(R)\rangle \le 1 - \frac{\delta}{2\|W^{mm}\|_\diamond}$. If $\text{dist}_\diamond(W', \mathcal{W}^{mm}) \ge \delta/2$, then $W'$ has the same SVM objective but is strictly bounded away from the solution set. Thus, for some $\epsilon := \epsilon(\delta) > 0$, $W'$ and its scaled-down version $\bar{W}_0(R)$ strictly violate an SVM constraint, achieving margin $\le 1 - \epsilon$. Without loss of generality, suppose $\bar{W}_0(R)$ violates the first constraint with i = 1. Thus, for a properly updated δ > 0 (a function of the initial δ > 0), for i = 1 and some support index $\tau \in \mathcal{T}_1$,
$$
\langle (x_{1\,\text{opt}_1} - x_{1\tau})\, z_1^\top, \bar{W}_0(R)\rangle \le 1 - \delta. \tag{112}
$$

Now, we will argue that this leads to a contradiction by proving $\mathcal{L}(\bar{W}^{mm}_R) < \mathcal{L}(\bar{W}(R))$ for sufficiently large R. To obtain the result, we establish a refined softmax probability control as in Step 1 by studying the distance to $\mathcal{L}_\star$. Following (111), denote the score function at $\bar{W}(R)$ via $\gamma^R_i := \gamma^{\bar{W}(R)}_i$. Similarly, let $\mathbf{s}^R_i = S(a^R_i)$ with $a^R_i = X_i\bar{W}(R)z_i$. Set the corresponding notation for the reference parameter $\bar{W}^{mm}_R$ as $\gamma^\star_i, \mathbf{s}^\star_i, a^\star_i$. Recall that $R \ge \|\bar{W}(R)\|_\diamond$ and $\Xi_\diamond := 1/\|W^{mm}\|_\diamond$. We note the following softmax inequalities
$$
s^\star_{i\,\text{opt}_i} \ge \frac{1}{1 + T e^{-R\Xi_\diamond}} \ge 1 - T e^{-R\Xi_\diamond} \quad \text{for all } i\in[n], \tag{113}
$$
$$
s^R_{i\,\text{opt}_i} \le \frac{1}{1 + e^{-(1-\delta)\|\bar{W}(R)\|_\diamond\Xi_\diamond}} \le \frac{1}{1 + e^{-(1-\delta)R\Xi_\diamond}} \quad \text{for } i = 1.
$$
The former inequality is thanks to $W^{mm}$ achieving margins $\ge 1$ on all tokens $[T] - \text{opt}_i$, and the latter arises from the δ-margin violation of $\bar{W}(R)$ at i = 1, i.e., Eq. (112). Since ℓ is strictly decreasing with a Lipschitz derivative and the scores are upper/lower bounded by an absolute constant (as the tokens are bounded and fixed), we have that $c_{\text{up}} \ge -\ell'(\gamma^W_i) \ge c_{\text{dn}}$ for some constants $c_{\text{up}} > c_{\text{dn}} > 0$. Thus, following Eq. (110), the score decomposition (111), and (113), we can write
$$
\begin{aligned}
\mathcal{L}(\bar{W}(R)) - \mathcal{L}_\star &\ge \frac{1}{n}\big[\ell(\gamma^{\bar{W}(R)}_1) - \ell(\gamma^{\text{opt}}_1)\big] \ge \frac{c_{\text{dn}}}{n}\big(\gamma^{\text{opt}}_1 - \gamma^{\bar{W}(R)}_1\big) \qquad (114) \\
&\ge \frac{c_{\text{dn}}}{n}\,c_{\text{opt}}\big(1 - s^R_{1\,\text{opt}_1}\big) \\
&\ge \frac{c_{\text{dn}}\,c_{\text{opt}}}{n}\,\frac{1}{1 + e^{(1-\delta)R\Xi_\diamond}}.
\end{aligned}
$$
Conversely, we upper bound the difference between $\mathcal{L}(\bar{W}^{mm}_R)$ and $\mathcal{L}_\star$ as follows. Define the worst-case loss difference for $\bar{W}^{mm}_R$ via $j = \arg\max_{i\in[n]}[\ell(\gamma^\star_i) - \ell(\gamma^{\text{opt}}_i)]$. Using (111) & (113), we write
$$
\begin{aligned}
\mathcal{L}(\bar{W}^{mm}_R) - \mathcal{L}_\star &\le \max_{i\in[n]}[\ell(\gamma^\star_i) - \ell(\gamma^{\text{opt}}_i)] \le c_{\text{up}}\cdot(\gamma^{\text{opt}}_j - \gamma^\star_j) \\
&\le c_{\text{up}}\cdot(1 - s^\star_{j\,\text{opt}_j})\,B \\
&\le c_{\text{up}}\cdot T e^{-R\Xi_\diamond}\,B.
\end{aligned}
$$

[Figure panels omitted: log-scale plots of 1 − correlation coefficient.]

Figure 11: Convergence behavior of GD when training attention weights (K, Q) ∈ R^{d×m} with random data and varying m. The misalignment between the attention SVM and GD, 1 − corr_coef(W^{mm}_{⋆,α}, KQ^⊤), is studied. W^{mm}_{⋆,α} is from (Att-SVM⋆) with GD tokens α and m = d. Subfigures with fixed n = 5 and T = 5 show that as m approaches or exceeds n, KQ^⊤ aligns more with W^{mm}_{⋆,α}.

Figure 12: Behavior of GD with a nonlinear nonconvex prediction head and multi-token compositions. (a) Evolution of correlation under varying d: blue, green, red and teal curves represent the evolution of 1 − corr_coef(W, W^{SVMeq}) for d = 4, 6, 8 and 10, respectively, as displayed in Figure 9 (upper). (b) Γ vs. correlation coefficient: over the 500 random instances discussed in Figure 9, we filter instances by constructing a masked set of tokens whose softmax output is < Γ, varying Γ from 10^{−16} to 10^{−6}. The corresponding results for 1 − corr_coef(W, W^{SVMeq}) are displayed in blue, green, red and teal curves.

Combining the last inequality and (114), we conclude that $\mathcal{L}(\bar{W}^{mm}_R) < \mathcal{L}(\bar{W}(R))$ whenever
$$
c_{\text{up}} T e^{-R\Xi_\diamond} B < \frac{c_{\text{dn}}\cdot c_{\text{opt}}}{n}\,\frac{1}{1 + e^{(1-\delta)R\Xi_\diamond}} \iff \frac{e^{R\Xi_\diamond}}{1 + e^{(1-\delta)R\Xi_\diamond}} > \frac{c_{\text{up}} T n B}{c_{\text{dn}} c_{\text{opt}}}.
$$
The latter inequality holds for all sufficiently large R: specifically, as soon as R obeys $R > \frac{1}{\delta\Xi_\diamond}\log\big(\frac{2 c_{\text{up}} T n B}{c_{\text{dn}} c_{\text{opt}}}\big)$. This completes the proof of the theorem by contradiction, since we obtained $\mathcal{L}(\bar{W}(R)) > \mathcal{L}(\bar{W}^{mm}_R)$.

F Supporting Experiments
In this section, we introduce implementation details and additional experiments. Code is available at
https://github.com/umich-sota/TF-as-SVM
We create a 1-layer self-attention model using PyTorch and train it with the SGD optimizer and a learning rate of η = 0.1. We apply normalized gradient descent to ensure divergence of the attention weights. The attention weight W is then updated via
$$
W(k+1) = W(k) - \eta\,\frac{\nabla\mathcal{L}(W(k))}{\|\nabla\mathcal{L}(W(k))\|_F}.
$$
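The update above can be implemented in a few lines; the following PyTorch sketch is a minimal illustration (not the repository code; the shapes, data, and logistic loss are simplified placeholders) of the normalized-GD step on W for a 1-layer attention model with a fixed linear head v.

```python
import torch

torch.manual_seed(0)
n, T, d = 8, 5, 4                          # illustrative sizes
X = torch.randn(n, T, d)                   # token sequences X_i
z = torch.randn(n, d)                      # query tokens z_i
v = torch.randn(d)                         # fixed linear prediction head
Y = torch.sign(torch.randn(n))             # +/-1 labels
eta = 0.1

W = torch.zeros(d, d, requires_grad=True)

def loss_fn(W):
    logits = torch.einsum('ntd,de,ne->nt', X, W, z)      # attention logits X_i W z_i
    attn = torch.softmax(logits, dim=1)                  # softmax over the T tokens
    feats = torch.einsum('nt,ntd->nd', attn, X)          # attention features x_i^W
    scores = Y * (feats @ v)                             # Y_i * v^T x_i^W
    return torch.nn.functional.softplus(-scores).mean()  # logistic loss

for _ in range(1000):
    grad, = torch.autograd.grad(loss_fn(W), W)
    with torch.no_grad():
        W -= eta * grad / grad.norm()                    # normalized GD step
```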
In the setting of the (K, Q)-parameterization, we noted that with extended training iterations the norm of the combined parameter KQ⊤ consistently rises, despite the gradient being numerically treated as zero due to computational limitations. To tackle this issue, we introduce a minor regularization penalty in the loss function, ensuring that the norms of K and Q remain within reasonable bounds. This adjustment is
$$
\widetilde{\mathcal{L}}(K, Q) = \mathcal{L}(K, Q) + \lambda\,(\|K\|_F^2 + \|Q\|_F^2).
$$
Here, we set λ to the smallest representable increment, i.e., the machine epsilon for which 1 + λ ≠ 1 in Python, which is around 2.22 × 10^{−16}. The K, Q parameters are therefore updated as follows:
$$
K(k+1) = K(k) - \eta\,\frac{\nabla_K\widetilde{\mathcal{L}}(K(k), Q(k))}{\|\nabla_K\widetilde{\mathcal{L}}(K(k), Q(k))\|_F}, \qquad Q(k+1) = Q(k) - \eta\,\frac{\nabla_Q\widetilde{\mathcal{L}}(K(k), Q(k))}{\|\nabla_Q\widetilde{\mathcal{L}}(K(k), Q(k))\|_F}.
$$
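A corresponding sketch for the (K, Q)-parameterization (again a simplified illustration, not the repository code; it continues the previous placeholder setup and reuses loss_fn, eta, and d from there) applies the same normalized updates with the tiny ridge term:

```python
import sys
import torch

m = d                                        # head dimension (illustrative)
K = (0.1 * torch.randn(d, m)).requires_grad_()
Q = (0.1 * torch.randn(d, m)).requires_grad_()
lam = sys.float_info.epsilon                 # ~2.22e-16 (machine epsilon)

for _ in range(1000):
    # W = K Q^T; add the tiny regularizer to keep ||K||, ||Q|| bounded.
    loss = loss_fn(K @ Q.T) + lam * (K.norm()**2 + Q.norm()**2)
    gK, gQ = torch.autograd.grad(loss, (K, Q))
    with torch.no_grad():
        K -= eta * gK / gK.norm()            # normalized update on K
        Q -= eta * gQ / gQ.norm()            # normalized update on Q
```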

[Figure panels omitted.]

Figure 13: Behavior of GD when selecting multiple tokens. (a) Relationship between the τ and λ parameters; (b) relationship between τ and the number of selected tokens; (c) distribution of the number of selected tokens over varying τ.

• As observed in previous work [TLZO23], due to the exponential form of the softmax nonlinearity and floating-point limitations, PyTorch is not guaranteed to select the optimal tokens when the score gap is too small. Therefore, in Figures 2, 5 and 6, we generate random tokens while ensuring that $\min_{i\in[n],\,t\ne\text{opt}_i} \gamma_{i\,\text{opt}_i} - \gamma_{it} \ge \gamma$, and we choose γ = 0.1 in our experiments (see the sketch below).
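A minimal NumPy sketch of this filtering step (hypothetical; the token distribution, head v, and sizes are illustrative rather than the repository's exact generator) rejection-samples sequences until the score gap is met:

```python
import numpy as np

rng = np.random.default_rng(0)
n, T, d, gap = 10, 5, 4, 0.1               # illustrative sizes and the score-gap threshold
v = rng.standard_normal(d)                 # linear prediction head (labels Y_i = 1 here)

def sample_sequence():
    # Rejection-sample tokens until the top score beats all others by at least `gap`.
    while True:
        X = rng.standard_normal((T, d))
        scores = X @ v                     # gamma_{it} = v^T x_{it}
        top_two = np.sort(scores)[-2:]
        if top_two[1] - top_two[0] >= gap:
            return X

data = np.stack([sample_sequence() for _ in range(n)])
print(data.shape)                          # (n, T, d)
```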

Rank sensitivity of the (K, Q)-parameterization (Figure 11). In Figure 4 and Lemma 1, we established both theoretically and empirically that the rank of the SVM solution, denoted $W^{mm}$ in (Att-SVM) or $W^{mm}_\star$ in (Att-SVM⋆), is at most max(n, d). Moving to Figure 11, we examine GD performance across various dimensions of $K, Q \in \mathbb{R}^{d\times m}$, keeping d = 20 fixed and varying m from 1 to 10. In the upper subfigure, we maintain a constant n = 5 and vary T within {5, 10, 15}, while in the lower subfigure, T is fixed at 5 and n varies within {5, 10, 15}. Results are depicted using blue, green, and red dashed curves, with both y-axes representing $1 - \text{corr\_coef}(W, W^{mm}_{\star,\alpha})$, where W is the GD solution and $W^{mm}_{\star,\alpha}$ is obtained from (Att-SVM⋆) by employing the token indices α selected by GD and setting the rank limit to m = d. Observing both subfigures, we note that a larger n necessitates a larger m for the attention weights KQ⊤ to accurately converge to the SVM solution (Figure 11, lower), while performance remains consistent across varying T values (Figure 11, upper). This observation further validates Lemma 1. Furthermore, the results demonstrate that W converges directionally towards $W^{mm}_{\star,\alpha}$ as long as m ≳ n, thereby confirming the assertion in our Theorem 5.
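For reference, the correlation coefficient used in these figures can be computed as the inner-product (cosine) similarity between vectorized matrices; the small sketch below is our own restatement of that metric, assuming this standard definition is the intended one.

```python
import torch

def corr_coef(W1: torch.Tensor, W2: torch.Tensor) -> float:
    # Cosine similarity between the vectorized matrices; a value of 1.0 means the
    # two matrices point in exactly the same direction.
    w1, w2 = W1.flatten(), W2.flatten()
    return float(w1 @ w2 / (w1.norm() * w2.norm()))

# Example: misalignment 1 - corr_coef(KQ^T, W_svm) for hypothetical matrices.
K, Q = torch.randn(20, 10), torch.randn(20, 10)
W_svm = torch.randn(20, 20)
print(1 - corr_coef(K @ Q.T, W_svm))
```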

Behavior of GD with a nonlinear nonconvex prediction head and multi-token compositions (Figure 12). To better investigate how the correlation changes with the data dimension d, we collect the solid curves of Figure 9 (upper) into Figure 12(a). Moreover, Figure 12(b) displays the average correlation over instances (refer to the scatter points in Figure 9 (lower)), considering masked tokens with softmax probability < Γ. Both findings highlight that a higher d enhances alignment. For d ≥ 8 or Γ ≤ 10^{−9}, the GD solution W achieves a correlation of > 0.99 with the SVM-equivalence $W^{\text{SVMeq}}$, defined in Section 6.

Investigation of Lemma 6 over different τ selections (Figure 13). Consider the setting of Section 6.1 and Lemma 6. Figure 10 explores the influence of λ on the number of tokens selected by the GD-derived attention weights: as λ increases, the likelihood of selecting more tokens also increases. Shifting focus to Figure 13, we examine the effect of τ. For each outcome, we generate random λ values, retaining pairs (λ, X) satisfying the τ constraints, with averages derived from 100 successful trials. The results indicate a positive correlation among τ, λ, and the number of selected tokens. Moreover, Figure 13(c) provides the precise distribution of selected token counts across various τ values (specifically, τ ∈ {3, 5, 7, 9}). The findings confirm that the number of selected tokens remains within the limit of τ, thus validating the assertion made in Lemma 6.

