Skip to content

Commit 39c18af

Browse files
authored
Merge branch 'site' into fix-previous-versions
2 parents b94ae30 + f22086d commit 39c18af

12 files changed

+211
-0
lines changed

_posts/2024-10-28-unleashing-ai-mobile.md

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
layout: blog_detail
33
title: "Unleashing the Power of AI on Mobile: LLM Inference for Llama 3.2 Quantized Models with ExecuTorch and KleidiAI"
44
author: Gian Marco Iodice, Arm and Digant Desai, Meta
5+
excerpt: "At the recent PyTorch Conference, Arm highlighted the widespread impact of its technology, spanning from cloud to edge, emphasizing its commitment to delivering its advanced AI computing capabilities seamlessly to millions of developers worldwide."
56
---
67

78
## Introduction

_posts/2024-11-01-cutlass-ping-pong-gemm-kernel.md

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
layout: blog_detail
33
title: "Deep Dive on CUTLASS Ping-Pong GEMM Kernel"
44
author: Less Wright, Adnan Hoque
5+
excerpt: "In this post, we provide an overview, with relevant FP8 inference kernel benchmarking, of the CUTLASS Ping-Pong GEMM kernel."
56
---
67

78
![Figure 1. FP8 GEMM Throughput Comparison CUTLASS vs Triton](/assets/images/cutlass-ping-pong-gemm-kernel/fg1.png){:style="width:100%"}

_posts/2024-11-21-rebellions.md

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
---
22
layout: blog_detail
33
title: "Rebellions Joins the PyTorch Foundation as a General Member"
4+
excerpt: "The PyTorch Foundation, a neutral home for the deep learning community to collaborate on the open source PyTorch framework and ecosystem, is announcing today that Rebellions has joined as a general member."
45
---
56

67
![Rebellions logo](/assets/images/rebellions-logo.svg){:style="max-width:350px;width:100%;float:right;margin: 20px;"}

_posts/2024-11-25-training-using-float8-fsdp2.md

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
layout: blog_detail
33
title: "Supercharging Training using float8 and FSDP2"
44
author: "IBM and Meta"
5+
excerpt: "In this blog, we will demonstrate how we achieve up to 50% throughput speedup while achieving loss and evaluation benchmark parity in training over FSDP1 bf16 training"
56
---
67

78
**IBM**: Tuan Hoang Trong, Alexei Karve, Yan Koyfman, Linsong Chu, Divya Kumari, Shweta Salaria, Robert Walkup, Praneet Adusumilli, Nirmit Desai, Raghu Ganti, Seetharami Seelam

_posts/2024-12-02-hadacore.md

+207
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
---
2+
layout: blog_detail
3+
title: "HadaCore: Tensor Core Accelerated Hadamard Transform Kernel"
4+
author: "IBM and Meta"
5+
excerpt: "Quantization is a method for improving model inference speeds by compressing model weights and performing (faster) computation in lower precision data types. However, quantization can result in accuracy loss due to the presence of outliers."
6+
---
7+
8+
**IBM**: Krish Agarwal, Rishi Astra, Adnan Hoque, Mudhakar Srivatsa, Raghu Ganti
9+
**Meta**: Less Wright, Sijia Chen
10+
11+
Quantization is a method for improving model inference speeds by compressing model weights and performing (faster) computation in lower precision data types. However, quantization can result in accuracy loss due to the presence of outliers. Recent works like [QuaRot](https://arxiv.org/abs/2404.00456), [SpinQuant](https://arxiv.org/abs/2405.16406), and [FlashAttention-3](https://arxiv.org/pdf/2407.08608) introduce methods to increase the numerical accuracy of INT4, INT8 and FP8 quantization in LLMs. These methods rely on [Hadamard Transforms](https://en.wikipedia.org/wiki/Hadamard_transform). In this blog, we present HadaCore, a Hadamard Transform CUDA kernel that achieves state-of-the-art performance on NVIDIA A100 and H100 GPUs. Our kernel achieves speedups of **1.1–1.4x** and **1.0–1.3x**, with a peak gain of **3.5x** and **3.6x** respectively, over Dao AI Lab’s [Fast Hadamard Transform Kernel](https://github.com/Dao-AILab/fast-hadamard-transform). We leverage a hardware-aware work decomposition that benefits from Tensor Core acceleration while maintaining quantization error reduction.
12+
13+
14+
15+
![Figure 1: Speedup of HadaCore vs Dao AI Hadamard CUDA kernel. A peak gain of 3.46x on the A100 is achieved using 128 rotation by 8.4M elements.](/assets/images/hadacore/fg1.png){:style="width:100%"}
16+
17+
*Figure 1: Speedup of HadaCore vs Dao AI Hadamard CUDA kernel. A peak gain of 3.46x on the A100 is achieved using 128 rotation by 8.4M elements.*
18+
19+
The [HadaCore Kernel is publicly available](https://github.com/pytorch-labs/applied-ai/tree/main/kernels/cuda/inference/hadamard_transform).
20+
21+
## Background
22+
23+
[QuaRot](https://arxiv.org/abs/2404.00456) and [SpinQuant](https://arxiv.org/abs/2405.16406) both propose methods to increase the numerical accuracy of INT4 and INT8 quantization in LLMs. Both methods rotate model activations since rotations are statistically likely to reduce the magnitude of outliers, as it “distributes” extreme values among other (less extreme) dimensions, and rotation is also an easily invertible operation using the inverse of the rotation matrix. These methods can also improve FP8 inference accuracy, such as in [FlashAttention-3](https://arxiv.org/pdf/2407.08608).
24+
25+
26+
![Figure 2. Transformer block showing online (red) and offline rotations (blue) in QuaRot](/assets/images/hadacore/fg2.png){:style="width:100%"}
27+
28+
29+
*Figure 2. Transformer block showing online (red) and offline rotations (blue) in QuaRot*
30+
31+
Applying these rotation matrices introduces model runtime overhead due to the online operations shown in Figure 2. These rotations can be applied through matrix multiplication, but the added overhead would diminish the benefits from quantization. Therefore, QuaRot and SpinQuant opt to use Walsh-Hadamard matrices, a special type of rotation matrix that can be applied faster than matrix multiplication using the [Fast Walsh-Hadamard Transform](https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform) algorithm. HadaCore is an optimized implementation of this algorithm for NVIDIA GPUs that support Tensor Cores.
32+
33+
## Tensor Core Accelerated Hadamard Transform
34+
35+
HadaCore leverages [NVIDIA Tensor Cores](https://www.nvidia.com/en-us/data-center/tensor-cores/), which are specialized compute units on NVIDIA GPUs optimized for matrix multiplication. To achieve this, our kernel performs a hardware-aware work decomposition of the Fast Walsh-Hadamard algorithm. This work decomposition ensures that we can utilize the [MMA PTX instructions](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=mma#multiply-and-accumulate-instruction-mma) that execute on the Tensor Core chip. HadaCore applies a 16×16 Hadamard transform to chunks of the input data. The computation can then be offloaded to the FP16 Tensor Core with usage of the [mma.m16n8k16](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=mma#matrix-fragments-for-mma-m16n8k16-with-floating-point-type) instruction. The warp-level parallelism for HadaCore is shown below.
36+
37+
38+
![Figure 3: HadaCore Parallelization, 1x256 vectors (rows) being rotated by a size 256 Hadamard.](/assets/images/hadacore/fg3.png){:style="width:100%"}
39+
40+
41+
*Figure 3: HadaCore Parallelization, 1x256 vectors (rows) being rotated by a size 256 Hadamard.*
42+
43+
We process fragments of 256 elements in parallel using warp-level Tensor Core operations to achieve up to a 256-size Hadamard transform. For further sizes, we shuffle data between warps and repeat.
44+
45+
## Microbenchmarks
46+
47+
We benchmark HadaCore against the[ Dao AI Lab Hadamard Kernel](https://github.com/Dao-AILab) on both NVIDIA H100 and A100 GPUs across varying Hadamard and input tensor sizes.
48+
49+
![Figure 4: HadaCore Kernel Speedup on NVIDIA A100 over Dao AI Lab Fast Hadamard Kernel](/assets/images/hadacore/fg4.png){:style="width:100%"}
50+
51+
52+
53+
*Figure 4: HadaCore Kernel Speedup on NVIDIA A100 over Dao AI Lab Fast Hadamard Kernel*
54+
55+
56+
![Color coded Speedup Table for NVIDIA A100, Green = Speedup over Baseline](/assets/images/hadacore/fg5.png){:style="width:100%; margin-top: 35px;"}
57+
58+
59+
*Color coded Speedup Table for NVIDIA A100, Green = Speedup over Baseline*
60+
61+
62+
![Figure 5: HadaCore Kernel Speedup on NVIDIA H100 over Dao AI Lab Fast Hadamard Kernel](/assets/images/hadacore/fg6.png){:style="width:100%; margin-top: 35px;"}
63+
64+
65+
*Figure 5: HadaCore Kernel Speedup on NVIDIA H100 over Dao AI Lab Fast Hadamard Kernel*
66+
67+
68+
![Color coded Speedup Table for NVIDIA H100, Green = Speedup over Baseline](/assets/images/hadacore/fg7.png){:style="width:100%; margin-top: 35px;"}
69+
70+
71+
*Color coded Speedup Table for NVIDIA H100, Green = Speedup over Baseline*
72+
73+
We showcase our speedup as the input tensor size (labeled element count) in our charts increase. Element count is the number of elements in the target matrix we are rotating. For example, in multi-head attention:
74+
75+
76+
The queries (Q), keys (K) and values (V) tensors are 4D tensors of size:
77+
78+
`(batch_size, seq_len, n_heads, head_dim)`
79+
80+
A Hadamard matrix of size `head_dim` is applied to these activation tensors, so we refer to this as using a Hadamard size of `head_dim` with an element count of:
81+
82+
`batch_size*seq_len*n_heads*head_dim.`
83+
84+
Common element counts for query rotations in an attention block:
85+
86+
87+
<table class="table table-bordered">
88+
<tr>
89+
<td><strong>Model \ Tokens</strong>
90+
</td>
91+
<td><strong>Prefill</strong>
92+
</td>
93+
<td><strong>Decoding</strong>
94+
</td>
95+
</tr>
96+
<tr>
97+
<td><strong>Llama-2 70b</strong>
98+
</td>
99+
<td>33,554,432 elements
100+
<br>
101+
128 Hadamard size
102+
<br>
103+
104+
(1 batch * 64 heads * 4096 tokens * 128 dimensional embeddings per head per token)
105+
</td>
106+
<td>8192 elements
107+
<br>
108+
128 Hadamard size
109+
<br>
110+
(1 batch * 64 heads * 1 token * 128 dimensional embeddings per head per token)
111+
</td>
112+
</tr>
113+
<tr>
114+
<td><strong>Llama-3 8b</strong>
115+
</td>
116+
<td>33,554,432 elements
117+
<br>
118+
128 Hadamard size
119+
<br>
120+
(1 batch * 32 heads * 8192 tokens * 128 dimensional embeddings per head per token)
121+
</td>
122+
<td>4,096 elements
123+
<br>
124+
128 Hadamard size
125+
<br>
126+
(1 batch * 32 heads * 1 token * 128 dimensional embeddings per head per token)
127+
</td>
128+
</tr>
129+
</table>
130+
131+
132+
HadaCore achieves **1.1–1.4x** speedup on A100 and **1.0–1.3x** speedup on H100 over Dao AI Lab’s Fast Hadamard kernel, with a peak gain of **3.5x and 3.6x**, respectively. For smaller sizes on H100, HadaCore’s gain decreases. For future work, we plan to incorporate usage of Hopper specific features like TMA and WGMMA for improved H100 performance.
133+
134+
## MMLU Benchmarks
135+
136+
We evaluated MMLU scores on a [Llama 3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) inference workload where the FlashAttention computation was performed in FP8. Newer generation [NVIDIA Hopper GPUs ](https://www.nvidia.com/en-us/data-center/technologies/hopper-architecture/)come equipped with FP8 Tensor Cores that deliver substantial compute gain over FP16.
137+
138+
Our results show the benefit of using HadaCore for accuracy preservation when combined with optimizations such as FP8 FlashAttention.
139+
140+
141+
<table class="table table-bordered">
142+
<tr>
143+
<td><strong>Format</strong>
144+
</td>
145+
<td><strong>Method</strong>
146+
</td>
147+
<td><strong>Llama3.1-8B</strong>
148+
<br>
149+
<strong>Avg. 5-Shot MMLU Accuracy</strong>
150+
</td>
151+
</tr>
152+
<tr>
153+
<td><strong>Q, K, V: FP16</strong>
154+
<br>
155+
<strong>FlashAttention: FP16</strong>
156+
</td>
157+
<td>N/A
158+
</td>
159+
<td>65.38
160+
</td>
161+
</tr>
162+
<tr>
163+
<td><strong>Q, K, V: FP16</strong>
164+
<br>
165+
<strong>FlashAttention: FP8</strong>
166+
</td>
167+
<td>No Hadamard
168+
</td>
169+
<td>64.40
170+
</td>
171+
</tr>
172+
<tr>
173+
<td><strong>Q, K, V: FP8</strong>
174+
<br>
175+
<strong>FlashAttention: FP8</strong>
176+
</td>
177+
<td>HadaCore
178+
</td>
179+
<td>65.09
180+
</td>
181+
</tr>
182+
<tr>
183+
<td><strong>Q, K, V: FP8</strong>
184+
<br>
185+
<strong>FlashAttention: FP8</strong>
186+
</td>
187+
<td>Dao AI Fast Hadamard Kernel
188+
</td>
189+
<td>65.45
190+
</td>
191+
</tr>
192+
</table>
193+
194+
195+
*Table 1: MMLU scores for Llama3.1 8B with FP16 baseline and FP8 attention using Hadamard transforms, comparing an implementation with explicit Hadamard matrix multiplications vs. HadaCore (**higher is better**)*
196+
197+
From the above MMLU scores, we note that for Llama3.1-8B inference with FP8 attention, HadaCore improves the quantization error introduced from computing attention in a lower precision.
198+
199+
## Conclusion
200+
201+
We showcased our speedups achieved by moving the Fast-Walsh Hadamard algorithm into a CUDA kernel that leverages Tensor Core acceleration and achieves a peak speedup of **3.5x** and **3.6x** over the Dao AI Fast-Hadamard kernel on NVIDIA A100 and H100, respectively.
202+
203+
Further, we showed on the MMLU benchmark that rotating with HadaCore maintains similar quantization error reduction to the Fast-Hadamard kernel, while providing computational acceleration.
204+
205+
## Future Work
206+
207+
We plan to implement a Triton version of our kernel and experiment with more advanced techniques such as kernel fusion to support fused Hadamard transform and quantization. Further, we plan to extend our kernel to support BF16 Tensor Core compute.

assets/images/hadacore/fg1.png

247 KB
Loading

assets/images/hadacore/fg2.png

59.7 KB
Loading

assets/images/hadacore/fg3.png

22.1 KB
Loading

assets/images/hadacore/fg4.png

247 KB
Loading

assets/images/hadacore/fg5.png

81.4 KB
Loading

assets/images/hadacore/fg6.png

236 KB
Loading

assets/images/hadacore/fg7.png

85 KB
Loading

0 commit comments

Comments
 (0)