Update 2024-08-07-flexattention.md #1707

Merged: 4 commits, Feb 6, 2025
10 changes: 5 additions & 5 deletions _posts/2024-08-07-flexattention.md
@@ -1,7 +1,7 @@
---
layout: blog_detail
title: "FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention"
author: "Team PyTorch: Horace He, Driss Guessous, Yanbo Liang, Joy Dong"
author: "Team PyTorch: Driss Guessous, Yanbo Liang, Joy Dong, Horace He"
---

![a cartoon chart flexing his muscles](/assets/images/flexattention/fg1.jpg){:style="width:100%"}
@@ -131,7 +131,7 @@ Alibi is similar to relative positional encodings with one exception \- it has a
alibi_bias = generate_alibi_bias() # [num_heads]

def alibi(score, b, h, q_idx, kv_idx):
-    bias = alibi_bias[h] * (q_idx - kv_idx)
+    bias = alibi_bias[h] * (kv_idx - q_idx)
    return score + bias
```
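
For context, here is a minimal sketch (not part of this diff) of how the corrected `score_mod` would be used with `flex_attention`. The `generate_alibi_bias` helper, the tensor shapes, and the device are illustrative assumptions; with the fixed sign, `kv_idx - q_idx` is non-positive for causal positions, so each head's positive slope penalizes attention to more distant keys.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

device = "cuda"               # assumed; FlexAttention targets GPU
B, H, S, D = 2, 8, 1024, 64   # assumed shapes for illustration

def generate_alibi_bias(num_heads: int) -> torch.Tensor:
    # Hypothetical stand-in for the post's helper: one positive slope per head,
    # following the standard ALiBi geometric schedule 2^(-8 * i / num_heads).
    slopes = [2.0 ** (-8.0 * (i + 1) / num_heads) for i in range(num_heads)]
    return torch.tensor(slopes, device=device)

alibi_bias = generate_alibi_bias(H)  # [num_heads]

def alibi(score, b, h, q_idx, kv_idx):
    # kv_idx - q_idx is <= 0 wherever kv_idx <= q_idx, so each head's positive
    # slope turns distance into a penalty on far-away keys.
    bias = alibi_bias[h] * (kv_idx - q_idx)
    return score + bias

q, k, v = (torch.randn(B, H, S, D, device=device) for _ in range(3))
out = flex_attention(q, k, v, score_mod=alibi)  # wrap with torch.compile for speed
```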

@@ -218,12 +218,12 @@ def sliding_window_causal(b, h, q_idx, kv_idx):
    return causal_mask & window_mask

# If you want to be cute...
-from torch.nn.attention import or_masks
+from torch.nn.attention import and_masks

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx <= SLIDING_WINDOW

-sliding_window_causal = or_masks(causal_mask, sliding_window)
+sliding_window_causal = and_masks(causal_mask, sliding_window)
```
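
For context, a minimal sketch (not part of this diff) of using the corrected composition end to end: `and_masks` keeps a (q_idx, kv_idx) pair only when *both* predicates hold, which is what a causal sliding window requires, whereas `or_masks` would admit any position satisfying either one. The window size, shapes, device, and the import path from `torch.nn.attention.flex_attention` are assumptions of this sketch.

```python
import torch
from torch.nn.attention.flex_attention import (
    and_masks,
    create_block_mask,
    flex_attention,
)

device = "cuda"               # assumed; FlexAttention targets GPU
B, H, S, D = 2, 8, 1024, 64   # assumed shapes for illustration
SLIDING_WINDOW = 256          # assumed window size

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx <= SLIDING_WINDOW

# Both predicates must hold: attend only to keys at or before the query,
# and only within the last SLIDING_WINDOW positions.
sliding_window_causal = and_masks(causal_mask, sliding_window)

# Materialize the sparse block mask once, then reuse it across calls.
block_mask = create_block_mask(
    sliding_window_causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device=device
)

q, k, v = (torch.randn(B, H, S, D, device=device) for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)  # [B, H, S, D]
```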

We benchmark it against `F.scaled_dot_product_attention` with a sliding window mask, as well as FA2 with a causal mask (as a reference point for performance). Not only are we significantly faster than `F.scaled_dot_product_attention`, we’re *also* faster than FA2 with a causal mask, since this mask has significantly more sparsity.
@@ -479,4 +479,4 @@ We want to highlight some prior work (and people) that have inspired FlexAttention
- The Jax team's work on SplashAttention
- Philippe Tillet and Keren Zhou for helping us with Triton
- Ali Hassani for discussions on neighborhood attention
-- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
+- Everybody who's complained about attention kernels not supporting their favorite attention variant :)