diff --git a/README.md b/README.md
index de6ed662..52496262 100644
--- a/README.md
+++ b/README.md
@@ -48,6 +48,12 @@
 | ✔️ [gelu_f16x2](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️|
 | ✔️ [gelu_f16x8](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️|
 | ✔️ [gelu_f16x8_pack](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️⭐️|
+| ✔️ [swish_f32](./swish/swish.cu)|f32|/|[link](./swish/)|⭐️|
+| ✔️ [swish_f32x4](./swish/swish.cu)|f32|/|[link](./swish/)|⭐️|
+| ✔️ [swish_f16](./swish/swish.cu)|f16|/|[link](./swish/)|⭐️|
+| ✔️ [swish_f16x2](./swish/swish.cu)|f16|/|[link](./swish/)|⭐️|
+| ✔️ [swish_f16x8](./swish/swish.cu)|f16|/|[link](./swish/)|⭐️|
+| ✔️ [swish_f16x8_pack](./swish/swish.cu)|f16|/|[link](./swish/)|⭐️⭐️|
 | ✔️ [embedding_f32](./embedding/embedding.cu)|f32|/|[link](./embedding/)|⭐️|
 | ✔️ [embedding_f32x4](./embedding/embedding.cu)|f32|/|[link](./embedding/)|⭐️|
 | ✔️ [embedding_f32x4_pack](./embedding/embedding.cu)|f32|/|[link](./embedding/)|⭐️|
diff --git a/swish/.gitignore b/swish/.gitignore
new file mode 100644
index 00000000..eb33da95
--- /dev/null
+++ b/swish/.gitignore
@@ -0,0 +1,10 @@
+*.so
+*.a
+*.dylib
+*.dll
+*.lib
+.DS_Store
+build
+*.whl
+tmp
+
diff --git a/swish/README.md b/swish/README.md
new file mode 100644
index 00000000..0cdc0362
--- /dev/null
+++ b/swish/README.md
@@ -0,0 +1,136 @@
+# Swish
+
+## 0x00 Overview
+
+This directory contains:
+
+- [X] swish_f32_kernel
+- [X] swish_f32x4_kernel (float4 vectorized version)
+- [X] swish_f16_kernel (fp16 version)
+- [X] swish_f16x2_kernel (fp16 vectorized version)
+- [X] swish_f16x8_kernel (fp16 vectorized version)
+- [X] swish_f16x8_pack_kernel (fp16 vectorized, packed version)
+- [X] PyTorch bindings
+
+
+## Test
+
+```bash
+# Build only for the Ada architecture; if unset, all architectures are compiled by default, which takes much longer: Volta, Ampere, Ada, Hopper, ...
+export TORCH_CUDA_ARCH_LIST=Ada
+python3 swish.py
+```
+
+Output:
+
+```bash
+-------------------------------------------------------------------------------------
+                                        S=1024, K=1024
+           out_f32: ['0.46177661 ', '-0.10888041 '], time:0.01246500ms
+         out_f32x4: ['0.46177661 ', '-0.10888041 '], time:0.01006508ms
+        out_f32_th: ['0.46177667 ', '-0.10888041 '], time:0.03012419ms
+-------------------------------------------------------------------------------------
+           out_f16: ['0.46191406 ', '-0.10894775 '], time:0.01299334ms
+         out_f16x2: ['0.46191406 ', '-0.10894775 '], time:0.01036119ms
+         out_f16x8: ['0.46191406 ', '-0.10894775 '], time:0.00979590ms
+     out_f16x8pack: ['0.46191406 ', '-0.10894775 '], time:0.00972557ms
+        out_f16_th: ['0.46191406 ', '-0.10888672 '], time:0.02423882ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=1024, K=2048
+           out_f32: ['-0.27797085 ', '0.71514565 '], time:0.01415992ms
+         out_f32x4: ['-0.27797085 ', '0.71514565 '], time:0.01159716ms
+        out_f32_th: ['-0.27797085 ', '0.71514559 '], time:0.02964258ms
+-------------------------------------------------------------------------------------
+           out_f16: ['-0.27807617 ', '0.71582031 '], time:0.01473880ms
+         out_f16x2: ['-0.27807617 ', '0.71582031 '], time:0.01404881ms
+         out_f16x8: ['-0.27807617 ', '0.71582031 '], time:0.01127148ms
+     out_f16x8pack: ['-0.27807617 ', '0.71582031 '], time:0.01101518ms
+        out_f16_th: ['-0.27807617 ', '0.71533203 '], time:0.02657008ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=1024, K=4096
+           out_f32: ['0.29988611 ', '-0.2541697 '], time:0.01959276ms
+         out_f32x4: ['0.29988611 ', '-0.2541697 '], time:0.01605868ms
+        out_f32_th: ['0.29988611 ', '-0.25416973 '], time:0.03745818ms
+-------------------------------------------------------------------------------------
+           out_f16: ['0.30004883 ', '-0.25415039 '], time:0.02078271ms
+         out_f16x2: ['0.30004883 ', '-0.25415039 '], time:0.01729155ms
+         out_f16x8: ['0.30004883 ', '-0.25415039 '], time:0.01489425ms
+     out_f16x8pack: ['0.30004883 ', '-0.25415039 '], time:0.01351643ms
+        out_f16_th: ['0.29980469 ', '-0.25415039 '], time:0.03149080ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=2048, K=1024
+           out_f32: ['-0.07777861 ', '-0.27842814 '], time:0.01640201ms
+         out_f32x4: ['-0.07777861 ', '-0.27842814 '], time:0.01180029ms
+        out_f32_th: ['-0.07777861 ', '-0.27842814 '], time:0.02952218ms
+-------------------------------------------------------------------------------------
+           out_f16: ['-0.07775879 ', '-0.27856445 '], time:0.01758027ms
+         out_f16x2: ['-0.07775879 ', '-0.27856445 '], time:0.01236153ms
+         out_f16x8: ['-0.07775879 ', '-0.27856445 '], time:0.01109338ms
+     out_f16x8pack: ['-0.07775879 ', '-0.27856445 '], time:0.01091790ms
+        out_f16_th: ['-0.07775879 ', '-0.27856445 '], time:0.02657914ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=2048, K=2048
+           out_f32: ['-0.14754841 ', '-0.21989606 '], time:0.01957679ms
+         out_f32x4: ['-0.14754841 ', '-0.21989606 '], time:0.01496792ms
+        out_f32_th: ['-0.14754841 ', '-0.21989603 '], time:0.03751612ms
+-------------------------------------------------------------------------------------
+           out_f16: ['-0.14758301 ', '-0.21984863 '], time:0.02085924ms
+         out_f16x2: ['-0.14758301 ', '-0.21984863 '], time:0.01961517ms
+         out_f16x8: ['-0.14758301 ', '-0.21984863 '], time:0.01386237ms
+     out_f16x8pack: ['-0.14758301 ', '-0.21984863 '], time:0.01334929ms
+        out_f16_th: ['-0.14758301 ', '-0.21984863 '], time:0.03151488ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=2048, K=4096
+           out_f32: ['1.07876182 ', '-0.27844051 '], time:0.03036070ms
+         out_f32x4: ['1.07876182 ', '-0.27844051 '], time:0.02339220ms
+        out_f32_th: ['1.07876182 ', '-0.27844048 '], time:0.05310464ms
+-------------------------------------------------------------------------------------
+           out_f16: ['1.078125 ', '-0.27832031 '], time:0.03291988ms
+         out_f16x2: ['1.078125 ', '-0.27832031 '], time:0.02590466ms
+         out_f16x8: ['1.078125 ', '-0.27832031 '], time:0.02027988ms
+     out_f16x8pack: ['1.078125 ', '-0.27832031 '], time:0.01811814ms
+        out_f16_th: ['1.07910156 ', '-0.27832031 '], time:0.04083204ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=4096, K=1024
+           out_f32: ['0.31169948 ', '-0.18232882 '], time:0.02427077ms
+         out_f32x4: ['0.31169948 ', '-0.18232882 '], time:0.01515222ms
+        out_f32_th: ['0.31169948 ', '-0.18232881 '], time:0.03754425ms
+-------------------------------------------------------------------------------------
+           out_f16: ['0.31152344 ', '-0.18237305 '], time:0.02679300ms
+         out_f16x2: ['0.31152344 ', '-0.18237305 '], time:0.01617312ms
+         out_f16x8: ['0.31152344 ', '-0.18237305 '], time:0.01357770ms
+     out_f16x8pack: ['0.31152344 ', '-0.18237305 '], time:0.01324248ms
+        out_f16_th: ['0.31152344 ', '-0.18225098 '], time:0.03149295ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=4096, K=2048
+           out_f32: ['1.5033319 ', '0.17473438 '], time:0.03030729ms
+         out_f32x4: ['1.5033319 ', '0.17473438 '], time:0.02150083ms
+        out_f32_th: ['1.5033319 ', '0.17473438 '], time:0.05257607ms
+-------------------------------------------------------------------------------------
+           out_f16: ['1.50390625 ', '0.17468262 '], time:0.03289509ms
+         out_f16x2: ['1.50390625 ', '0.17468262 '], time:0.03073120ms
+         out_f16x8: ['1.50390625 ', '0.17468262 '], time:0.01862860ms
+     out_f16x8pack: ['1.50390625 ', '0.17468262 '], time:0.01772857ms
+        out_f16_th: ['1.50390625 ', '0.17468262 '], time:0.04082441ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=4096, K=4096
+           out_f32: ['-0.05288643 ', '-0.14218464 '], time:0.19254756ms
+         out_f32x4: ['-0.05288643 ', '-0.14218464 '], time:0.19258785ms
+        out_f32_th: ['-0.05288643 ', '-0.14218464 '], time:0.48660636ms
+-------------------------------------------------------------------------------------
+           out_f16: ['-0.052948 ', '-0.14221191 '], time:0.05689216ms
+         out_f16x2: ['-0.052948 ', '-0.14221191 '], time:0.04335928ms
+         out_f16x8: ['-0.052948 ', '-0.14221191 '], time:0.03096652ms
+     out_f16x8pack: ['-0.052948 ', '-0.14221191 '], time:0.02706647ms
+        out_f16_th: ['-0.05288696 ', '-0.14221191 '], time:0.05971408ms
+-------------------------------------------------------------------------------------
+
+```
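For reference, every kernel in this directory computes Swish element-wise, `y = x * sigmoid(x) = x / (1 + exp(-x))`; the `*_th` rows in the log above come from the equivalent PyTorch expression (`torch_swish` in `swish.py` below). A minimal sketch of that baseline, with an illustrative function name:

```python
import torch

def swish_ref(x: torch.Tensor) -> torch.Tensor:
    # Swish / SiLU: y = x * sigmoid(x) = x / (1 + exp(-x))
    return x * torch.sigmoid(x)
```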
diff --git a/swish/swish.cu b/swish/swish.cu
new file mode 100644
index 00000000..0438dcc8
--- /dev/null
+++ b/swish/swish.cu
@@ -0,0 +1,157 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <float.h>
+#include <vector>
+#include <algorithm>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <torch/types.h>
+#include <torch/extension.h>
+
+#define WARP_SIZE 32
+#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
+#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
+#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
+#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
+
+// -------------------------------------- FP32 --------------------------------------
+// Swish: y = x * sigmoid(x), x: N, y: N
+__device__ __forceinline__ float swish(float x) {
+  return x / (1.0f + expf(-x));
+}
+
+__global__ void swish_f32_kernel(float* x, float* y, int N) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < N) y[idx] = swish(x[idx]);
+}
+
+__global__ void swish_f32x4_kernel(float* x, float* y, int N) {
+  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
+  if (idx < N) {
+    float4 reg_x = FLOAT4(x[idx]);
+    float4 reg_y;
+    reg_y.x = swish(reg_x.x);
+    reg_y.y = swish(reg_x.y);
+    reg_y.z = swish(reg_x.z);
+    reg_y.w = swish(reg_x.w);
+    FLOAT4(y[idx]) = reg_y;
+  }
+}
+
+// -------------------------------------- FP16 --------------------------------------
+__device__ __forceinline__ half swish_half(half x) {
+  return __hmul(x, __hdiv(
+    __float2half(1.0f), __hadd(__float2half(1.0f), hexp(__hneg(x)))));
+}
+
+__global__ void swish_f16_kernel(half* x, half* y, int N) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < N) y[idx] = swish_half(x[idx]);
+}
+
+__global__ void swish_f16x2_kernel(half* x, half* y, int N) {
+  int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
+  if (idx < N) {
+    half2 reg_x = HALF2(x[idx]);
+    half2 reg_y;
+    reg_y.x = swish_half(reg_x.x);
+    reg_y.y = swish_half(reg_x.y);
+    HALF2(y[idx]) = reg_y;
+  }
+}
+
+__global__ void swish_f16x8_kernel(half* x, half* y, int N) {
+  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
+  half2 reg_x_0 = HALF2(x[idx + 0]);
+  half2 reg_x_1 = HALF2(x[idx + 2]);
+  half2 reg_x_2 = HALF2(x[idx + 4]);
+  half2 reg_x_3 = HALF2(x[idx + 6]);
+  half2 reg_y_0, reg_y_1, reg_y_2, reg_y_3;
+  reg_y_0.x = swish_half(reg_x_0.x);
+  reg_y_0.y = swish_half(reg_x_0.y);
+  reg_y_1.x = swish_half(reg_x_1.x);
+  reg_y_1.y = swish_half(reg_x_1.y);
+  reg_y_2.x = swish_half(reg_x_2.x);
+  reg_y_2.y = swish_half(reg_x_2.y);
+  reg_y_3.x = swish_half(reg_x_3.x);
+  reg_y_3.y = swish_half(reg_x_3.y);
+  if ((idx + 0) < N) { HALF2(y[idx + 0]) = reg_y_0; }
+  if ((idx + 2) < N) { HALF2(y[idx + 2]) = reg_y_1; }
+  if ((idx + 4) < N) { HALF2(y[idx + 4]) = reg_y_2; }
+  if ((idx + 6) < N) { HALF2(y[idx + 6]) = reg_y_3; }
+}
+
+__global__ void swish_f16x8_pack_kernel(half* x, half* y, int N) {
+  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
+  half pack_x[8], pack_y[8];
+  // load 8 halves (128 bits) with a single vectorized access
+  LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]);
+
+  #pragma unroll
+  for (int i = 0; i < 8; i++) {
+    pack_y[i] = swish_half(pack_x[i]);
+  }
+  // store 8 halves (128 bits) with a single vectorized access
+  if ((idx + 7) < N) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
+}
+
+// --------------------- PyTorch bindings for custom kernel -----------------------
+#define STRINGFY(str) #str
+#define TORCH_BINDING_COMMON_EXTENSION(func) \
+  m.def(STRINGFY(func), &func, STRINGFY(func));
+
+#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                   \
+if (((T).options().dtype() != (th_type))) {                    \
+  std::cout << "Tensor Info:" << (T).options() << std::endl;   \
+  throw std::runtime_error("values must be " #th_type);        \
+}
+
+#define TORCH_BINDING_SWISH(packed_type, th_type, element_type, n_elements)   \
+void swish_##packed_type(torch::Tensor x, torch::Tensor y) {                  \
+  CHECK_TORCH_TENSOR_DTYPE(x, (th_type))                                      \
+  CHECK_TORCH_TENSOR_DTYPE(y, (th_type))                                      \
+  const int ndim = x.dim();                                                   \
+  if (ndim != 2) {                                                            \
+    int N = 1;                                                                \
+    for (int i = 0; i < ndim; ++i) { N *= x.size(i); }                        \
+    dim3 block(256 / (n_elements));                                           \
+    dim3 grid((N + 256 - 1) / 256);                                           \
+    swish_##packed_type##_kernel<<<grid, block>>>(                            \
+      reinterpret_cast<element_type*>(x.data_ptr()),                          \
+      reinterpret_cast<element_type*>(y.data_ptr()), N);                      \
+  } else {                                                                    \
+    const int S = x.size(0);                                                  \
+    const int K = x.size(1);                                                  \
+    const int N = S * K;                                                      \
+    if ((K / (n_elements)) <= 1024) {                                         \
+      dim3 block(K / (n_elements));                                           \
+      dim3 grid(S);                                                           \
+      swish_##packed_type##_kernel<<<grid, block>>>(                          \
+        reinterpret_cast<element_type*>(x.data_ptr()),                        \
+        reinterpret_cast<element_type*>(y.data_ptr()), N);                    \
+    } else {                                                                  \
+      int N = 1;                                                              \
+      for (int i = 0; i < ndim; ++i) { N *= x.size(i); }                      \
+      dim3 block(256 / (n_elements));                                         \
+      dim3 grid((N + 256 - 1) / 256);                                         \
+      swish_##packed_type##_kernel<<<grid, block>>>(                          \
+        reinterpret_cast<element_type*>(x.data_ptr()),                        \
+        reinterpret_cast<element_type*>(y.data_ptr()), N);                    \
+    }                                                                         \
+  }                                                                           \
+}
+
+TORCH_BINDING_SWISH(f32,        torch::kFloat32, float, 1)
+TORCH_BINDING_SWISH(f32x4,      torch::kFloat32, float, 4)
+TORCH_BINDING_SWISH(f16,        torch::kHalf,    half,  1)
+TORCH_BINDING_SWISH(f16x2,      torch::kHalf,    half,  2)
+TORCH_BINDING_SWISH(f16x8,      torch::kHalf,    half,  8)
+TORCH_BINDING_SWISH(f16x8_pack, torch::kHalf,    half,  8)
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  TORCH_BINDING_COMMON_EXTENSION(swish_f32)
+  TORCH_BINDING_COMMON_EXTENSION(swish_f32x4)
+  TORCH_BINDING_COMMON_EXTENSION(swish_f16)
+  TORCH_BINDING_COMMON_EXTENSION(swish_f16x2)
+  TORCH_BINDING_COMMON_EXTENSION(swish_f16x8)
+  TORCH_BINDING_COMMON_EXTENSION(swish_f16x8_pack)
+}
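The launch configuration chosen by `TORCH_BINDING_SWISH` is easy to miss behind the macro continuations: for a contiguous 2D `(S, K)` input it launches one block per row with `K / n_elements` threads whenever that fits in a block (≤ 1024 threads), and otherwise (non-2D input, or `K / n_elements > 1024`) it falls back to a flat 1D grid covering all `N` elements with 256 elements per block. A small sketch of that arithmetic in plain Python (illustrative only, not part of the extension):

```python
def launch_shape(S: int, K: int, n_elements: int):
    """Mirror the grid/block choice TORCH_BINDING_SWISH makes for a 2D (S, K) tensor."""
    if K // n_elements <= 1024:
        return {"grid": S, "block": K // n_elements}                # one block per row
    N = S * K
    return {"grid": (N + 255) // 256, "block": 256 // n_elements}   # flat 1D fallback

print(launch_shape(4096, 4096, 4))  # swish_f32x4: {'grid': 4096, 'block': 1024}
print(launch_shape(4096, 4096, 1))  # swish_f32:   {'grid': 65536, 'block': 256}
print(launch_shape(1024, 1024, 8))  # swish_f16x8: {'grid': 1024, 'block': 128}
```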
diff --git a/swish/swish.py b/swish/swish.py
new file mode 100644
index 00000000..bc330235
--- /dev/null
+++ b/swish/swish.py
@@ -0,0 +1,85 @@
+import torch
+import time
+from torch.utils.cpp_extension import load
+from typing import Optional
+from functools import partial
+
+torch.set_grad_enabled(False)
+
+# Load the CUDA kernel as a python module
+lib = load(name='swish_lib',
+           sources=['swish.cu'],
+           extra_cuda_cflags=[
+               "-O3",
+               "-U__CUDA_NO_HALF_OPERATORS__",
+               "-U__CUDA_NO_HALF_CONVERSIONS__",
+               "-U__CUDA_NO_HALF2_OPERATORS__",
+               "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+               "--expt-relaxed-constexpr",
+               "--expt-extended-lambda",
+               "--use_fast_math",
+           ],
+           extra_cflags=['-std=c++17'])
+
+
+def run_benchmark(perf_func: callable, x: torch.Tensor, tag: str,
+                  out: Optional[torch.Tensor] = None, warmup: int = 10,
+                  iters: int = 1000, show_all: bool = False):
+    if out is not None:
+        out.fill_(0)
+    # warmup
+    if out is not None:
+        for i in range(warmup):
+            perf_func(x, out)
+    else:
+        for i in range(warmup):
+            out = perf_func(x)
+    torch.cuda.synchronize()
+    start = time.time()
+    # iters
+    if out is not None:
+        for i in range(iters):
+            perf_func(x, out)
+    else:
+        for i in range(iters):
+            out = perf_func(x)
+    torch.cuda.synchronize()
+    end = time.time()
+    total_time = (end - start) * 1000  # ms
+    mean_time = total_time / iters
+    out_info = f"out_{tag}"
+    out_val = out.flatten().detach().cpu().numpy().tolist()[:2]
+    out_val = [round(v, 8) for v in out_val]
+    out_val = [f"{v:<12}" for v in out_val]
+    print(f"{out_info:>18}: {out_val}, time:{mean_time:.8f}ms")
+    if show_all: print(out)
+    return out, mean_time
+
+
+def torch_swish(x, out=None):
+    if out is None:
+        return x * torch.sigmoid(x)
+    else:
+        torch.sigmoid(x, out=out)
+        out.mul_(x)
+        return out
+
+
+Ss = [1024, 2048, 4096]
+Ks = [1024, 2048, 4096]
+SKs = [(S, K) for S in Ss for K in Ks]
+
+for (S, K) in SKs:
+    print("-" * 85)
+    print(" " * 40 + f"S={S}, K={K}")
+    x = torch.randn((S, K)).cuda().float().contiguous()
+    y = torch.zeros_like(x).cuda().float().contiguous()
+    run_benchmark(lib.swish_f32, x, "f32", y)
+    run_benchmark(lib.swish_f32x4, x, "f32x4", y)
+    run_benchmark(torch_swish, x, "f32_th", y)
+    print("-" * 85)
+    x_f16 = x.half().contiguous()
+    y_f16 = y.half().contiguous()
+    run_benchmark(lib.swish_f16, x_f16, "f16", y_f16)
+    run_benchmark(lib.swish_f16x2, x_f16, "f16x2", y_f16)
+    run_benchmark(lib.swish_f16x8, x_f16, "f16x8", y_f16)
+    run_benchmark(lib.swish_f16x8_pack, x_f16, "f16x8pack", y_f16)
+    run_benchmark(torch_swish, x_f16, "f16_th", y_f16)
+    print("-" * 85)
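`swish.py` only prints the first two output values of each run as a visual sanity check; a stricter element-wise comparison against the PyTorch reference could look like the sketch below. It assumes the extension has already been built and loaded as `lib`, exactly as in `swish.py`, and the tolerances are illustrative, with a looser bound for the fp16 kernels since they evaluate `exp` in half precision.

```python
import torch

def check_swish(lib, S: int = 1024, K: int = 1024) -> None:
    # fp32 reference: y = x * sigmoid(x)
    x = torch.randn(S, K, device="cuda", dtype=torch.float32).contiguous()
    ref = x * torch.sigmoid(x)

    y = torch.zeros_like(x)
    lib.swish_f32x4(x, y)
    assert torch.allclose(y, ref, atol=1e-5), "swish_f32x4 mismatch"

    x_f16 = x.half().contiguous()
    y_f16 = torch.zeros_like(x_f16)
    lib.swish_f16x8_pack(x_f16, y_f16)
    assert torch.allclose(y_f16.float(), ref, atol=1e-2), "swish_f16x8_pack mismatch"
    print("swish kernels match the PyTorch reference")
```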