From a462c36a4b247076f171883d1517debb38780d10 Mon Sep 17 00:00:00 2001 From: Horace He Date: Mon, 12 Aug 2024 22:45:33 -0700 Subject: [PATCH 1/2] Update 2024-08-07-flexattention.md --- _posts/2024-08-07-flexattention.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/_posts/2024-08-07-flexattention.md b/_posts/2024-08-07-flexattention.md index 4c34879d33b6..d601bc085c58 100644 --- a/_posts/2024-08-07-flexattention.md +++ b/_posts/2024-08-07-flexattention.md @@ -218,12 +218,12 @@ def sliding_window_causal(b, h, q_idx, kv_idx): return causal_mask & window_mask # If you want to be cute... -from torch.nn.attention import or_masks +from torch.nn.attention import and_masks def sliding_window(b, h, q_idx, kv_idx) return q_idx - kv_idx <= SLIDING_WINDOW -sliding_window_causal = or_masks(causal_mask, sliding_window) +sliding_window_causal = and_masks(causal_mask, sliding_window) ``` We benchmark it against `F.scaled_dot_product_attention` with a sliding window mask as well as FA2 with a causal mask (as a reference point for performance). Not only are we significantly faster than `F.scaled_dot_product_attention`, we’re *also* significantly faster than FA2 with a causal mask as this mask has significantly more sparsity. @@ -479,4 +479,4 @@ We want to highlight some prior work (and people) that have inspired FlexAttenti - The Jax team's work on SplashAttention - Philippe Tillet and Keren Zhou for helping us with Triton - Ali Hassani for discussions on neighborhood attention -- Everybody who's complained about attention kernels not supporting their favorite attention variant :) \ No newline at end of file +- Everybody who's complained about attention kernels not supporting their favorite attention variant :) From 8d5c958534f476c07e7129fc8bfa9e8931a1a78d Mon Sep 17 00:00:00 2001 From: Horace He Date: Thu, 6 Feb 2025 11:39:16 -0800 Subject: [PATCH 2/2] Update 2024-08-07-flexattention.md --- _posts/2024-08-07-flexattention.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/_posts/2024-08-07-flexattention.md b/_posts/2024-08-07-flexattention.md index d601bc085c58..acfc1fc40f01 100644 --- a/_posts/2024-08-07-flexattention.md +++ b/_posts/2024-08-07-flexattention.md @@ -1,7 +1,7 @@ --- layout: blog_detail title: "FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention" -author: "Team PyTorch: Horace He, Driss Guessous, Yanbo Liang, Joy Dong" +author: "Team PyTorch: Driss Guessous, Yanbo Liang, Joy Dong, Horace He" --- ![a cartoon chart flexing his muscles](/assets/images/flexattention/fg1.jpg){:style="width:100%"} @@ -131,7 +131,7 @@ Alibi is similar to relative positional encodings with one exception \- it has a alibi_bias = generate_alibi_bias() # [num_heads] def alibi(score, b, h, q_idx, kv_idx): - bias = alibi_bias[h] * (q_idx - kv_idx) + bias = alibi_bias[h] * (kv_idx - q_idx) return score + bias ```