Skip to content

Commit e0da7bd

Browse files
authored
fix: softmax.cu spell error (#371)
* softmax.cu spell error * format * format
1 parent 0d05779 commit e0da7bd

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
<div align='center'>
4545
<img src='https://github.com/user-attachments/assets/a5ec4320-d2f9-4254-888a-170b2d9e3784' height=170px>
4646
</div>
47-
-->
47+
-->
4848

4949
- [2025-01-08]: **[🤖ffpa-attn](https://github.com/xlite-dev/ffpa-attn.git)** is released! Yet another Faster Flash Prefill Attention with O(1)🎉SRAM complexity for large headdim, **1.8x~3x↑**🎉 vs SDPA EA: [📈L20 ~1.9x↑🎉](https://github.com/xlite-dev/ffpa-attn?tab=readme-ov-file#L1-bench-l20), [📈A30 ~1.8x↑🎉](https://github.com/xlite-dev/ffpa-attn?tab=readme-ov-file#L1-bench-a30), [📈4090 ~2.1x↑🎉](https://github.com/xlite-dev/ffpa-attn?tab=readme-ov-file#L1-bench-4090).
5050

@@ -54,7 +54,7 @@
5454
<img src='https://github.com/user-attachments/assets/447e2937-f7c8-47c8-8550-8c0c71b910e6' height="170px" width="229px">
5555
<img src='https://github.com/user-attachments/assets/65a8d564-8fa7-4d66-86b9-e238feb86143' height="170px" width="229px">
5656
</div>
57-
-->
57+
-->
5858
<div align='center'>
5959
<img height="320px" alt="image" src="https://github.com/user-attachments/assets/ed30185b-2e11-4293-832f-43e9003d6ad9" />
6060
</div>

kernels/softmax/softmax.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -333,14 +333,14 @@ __global__ void online_safe_softmax_f32_per_token_kernel(const float *x,
333333
// for softmax)
334334
int local_tid = threadIdx.x;
335335
int global_tid = blockIdx.x * NUM_THREADS + threadIdx.x;
336-
const int WAPR_NUM = NUM_THREADS / WARP_SIZE;
336+
const int WARP_NUM = NUM_THREADS / WARP_SIZE;
337337
int warp_id = local_tid / WARP_SIZE;
338338
int lane_id = local_tid % WARP_SIZE;
339339
MD val;
340340
val.m = global_tid < N ? x[global_tid] : -FLT_MAX;
341341
val.d = global_tid < N ? 1.0f : 0.0f;
342342

343-
__shared__ MD shared[WAPR_NUM];
343+
__shared__ MD shared[WARP_NUM];
344344
MD res = warp_reduce_md_op<WARP_SIZE>(val);
345345

346346
if (lane_id == 0)
@@ -349,7 +349,7 @@ __global__ void online_safe_softmax_f32_per_token_kernel(const float *x,
349349

350350
if (local_tid < WARP_SIZE) {
351351
MD block_res = shared[local_tid];
352-
block_res = warp_reduce_md_op<WAPR_NUM>(block_res);
352+
block_res = warp_reduce_md_op<WARP_NUM>(block_res);
353353
if (local_tid == 0) {
354354
shared[0] = block_res;
355355
}
@@ -371,7 +371,7 @@ online_safe_softmax_f32x4_pack_per_token_kernel(float *x, float *y, int N) {
371371
int local_tid = threadIdx.x;
372372
int global_tid = (blockIdx.x * NUM_THREADS + local_tid) * 4;
373373

374-
const int WAPR_NUM = NUM_THREADS / WARP_SIZE;
374+
const int WARP_NUM = NUM_THREADS / WARP_SIZE;
375375
int warp_id = local_tid / WARP_SIZE;
376376
int lane_id = local_tid % WARP_SIZE;
377377
// compare local max value
@@ -382,15 +382,15 @@ online_safe_softmax_f32x4_pack_per_token_kernel(float *x, float *y, int N) {
382382

383383
MD local_md = {local_m, local_d};
384384
MD res = warp_reduce_md_op<WARP_SIZE>(local_md);
385-
__shared__ MD shared[WAPR_NUM];
385+
__shared__ MD shared[WARP_NUM];
386386

387387
if (lane_id == 0)
388388
shared[warp_id] = res;
389389
__syncthreads();
390390
// do block reduce
391391
if (local_tid < WARP_SIZE) {
392392
MD block_res = shared[local_tid];
393-
block_res = warp_reduce_md_op<WAPR_NUM>(block_res);
393+
block_res = warp_reduce_md_op<WARP_NUM>(block_res);
394394
if (local_tid == 0)
395395
shared[0] = block_res;
396396
}

0 commit comments

Comments
 (0)