From 5f517e283728b1ead1e7cba088c574e19d964c72 Mon Sep 17 00:00:00 2001
From: wangzijian <2087291150@qq.com>
Date: Wed, 16 Oct 2024 07:48:51 +0000
Subject: [PATCH 1/5] [SWISH][FP16] first commit, add FP32, FP16 and
 fp16x8_pack kernels.

---
 swish/.gitignore |  10 +++
 swish/README.md  |   0
 swish/swish.cu   | 156 +++++++++++++++++++++++++++++++++++++++++++++++
 swish/swish.py   |  85 ++++++++++++++++++++++++++
 4 files changed, 251 insertions(+)
 create mode 100644 swish/.gitignore
 create mode 100644 swish/README.md
 create mode 100644 swish/swish.cu
 create mode 100644 swish/swish.py

diff --git a/swish/.gitignore b/swish/.gitignore
new file mode 100644
index 00000000..eb33da95
--- /dev/null
+++ b/swish/.gitignore
@@ -0,0 +1,10 @@
+*.so
+*.a
+*.dylib
+*.dll
+*.lib
+.DS_Store
+build
+*.whl
+tmp
+
diff --git a/swish/README.md b/swish/README.md
new file mode 100644
index 00000000..e69de29b
diff --git a/swish/swish.cu b/swish/swish.cu
new file mode 100644
index 00000000..eebaa9b3
--- /dev/null
+++ b/swish/swish.cu
@@ -0,0 +1,156 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <float.h>
+#include <vector>
+#include <algorithm>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <torch/extension.h>
+
+#define WARP_SIZE 32
+#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
+#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
+#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
+#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
+
+// -------------------------------------- FP32 --------------------------------------
+// Swish x: N, y: N y=x*sigmoid(x)
+__device__ __forceinline__ float swish(float x) {
+    return x / (1.0f + expf(-x));
+}
+
+__global__ void swish_f32_kernel(float* x, float* y, int N) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < N) y[idx] = swish(x[idx]);
+}
+
+__global__ void swish_f32x4_kernel(float* x, float* y, int N) {
+    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
+    if (idx < N) {
+        float4 reg_x = FLOAT4(x[idx]);
+        float4 reg_y;
+        reg_y.x = swish(reg_x.x);
+        reg_y.y = swish(reg_x.y);
+        reg_y.z = swish(reg_x.z);
+        reg_y.w = swish(reg_x.w);
+        FLOAT4(y[idx]) = reg_y;
+    }
+}
+
+// -------------------------------------- FP16 --------------------------------------
+__device__ __forceinline__ half swish_half(half x) {
+    return __hmul(x, __hdiv(__float2half(1.0f), __hadd(__float2half(1.0f), hexp(__hneg(x)))));
+}
+
+__global__ void swish_f16_kernel(half* x, half* y, int N) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < N) y[idx] = swish_half(x[idx]);
+}
+
+__global__ void swish_f16x2_kernel(half* x, half* y, int N) {
+    int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
+    if (idx < N) {
+        half2 reg_x = HALF2(x[idx]);
+        half2 reg_y;
+        reg_y.x = swish_half(reg_x.x);
+        reg_y.y = swish_half(reg_x.y);
+        HALF2(y[idx]) = reg_y;
+    }
+}
+
+__global__ void swish_f16x8_kernel(half* x, half* y, int N) {
+    int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
+    half2 reg_x_0 = HALF2(x[idx + 0]);
+    half2 reg_x_1 = HALF2(x[idx + 2]);
+    half2 reg_x_2 = HALF2(x[idx + 4]);
+    half2 reg_x_3 = HALF2(x[idx + 6]);
+    half2 reg_y_0, reg_y_1, reg_y_2, reg_y_3;
+    reg_y_0.x = swish_half(reg_x_0.x);
+    reg_y_0.y = swish_half(reg_x_0.y);
+    reg_y_1.x = swish_half(reg_x_1.x);
+    reg_y_1.y = swish_half(reg_x_1.y);
+    reg_y_2.x = swish_half(reg_x_2.x);
+    reg_y_2.y = swish_half(reg_x_2.y);
+    reg_y_3.x = swish_half(reg_x_3.x);
+    reg_y_3.y = swish_half(reg_x_3.y);
+    if ((idx + 0) < N) { HALF2(y[idx + 0]) = reg_y_0; }
+    if ((idx + 2) < N) { HALF2(y[idx + 2]) = reg_y_1; }
+    if ((idx + 4) < N) { HALF2(y[idx + 4]) = reg_y_2; }
+    if ((idx + 6) < N) { HALF2(y[idx + 6]) = reg_y_3; }
+}
+
+__global__ void swish_f16x8_pack_kernel(half* x, half* y, int N) {
+    int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
+    half pack_x[8], pack_y[8];
+    LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]);
+
+#pragma unroll
+    for (int i = 0; i < 8; i++) {
+        pack_y[i] = swish_half(pack_x[i]);
+    }
+    if ((idx + 7) < N) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
+}
+
+// --------------------- PyTorch bindings for custom kernel -----------------------
+#define STRINGFY(str) #str
+#define TORCH_BINDING_COMMON_EXTENSION(func) \
+  m.def(STRINGFY(func), &func, STRINGFY(func));
+
+#define CHECK_TORCH_TENSOR_DTYPE(T, th_type)                 \
+if(((T).options().dtype() != (th_type))) {                   \
+  std::cout << "Tensor Info:" << (T).options() << std::endl; \
+  throw std::runtime_error("values must be "#th_type);       \
+}
+
+#define TORCH_BINDING_SWISH(packed_type, th_type, element_type, n_elements) \
+void swish_##packed_type(torch::Tensor x, torch::Tensor y) { \
+  CHECK_TORCH_TENSOR_DTYPE(x, (th_type))                                      \
+  CHECK_TORCH_TENSOR_DTYPE(y, (th_type))                                      \
+  const int ndim = x.dim();                                                   \
+  if (ndim != 2) {                                                            \
+    int N = 1;                                                                \
+    for (int i = 0; i < ndim; ++i) { N *= x.size(i); }                        \
+    dim3 block(256 / (n_elements));                                           \
+    dim3 grid((N + 256 - 1) / 256);                                           \
+    swish_##packed_type##_kernel<<<grid, block>>>( \
+      reinterpret_cast<element_type*>(x.data_ptr()),                          \
+      reinterpret_cast<element_type*>(y.data_ptr()), N);                      \
+  } else {                                                                    \
+    const int S = x.size(0);                                                  \
+    const int K = x.size(1);                                                  \
+    const int N = S * K;                                                      \
+    if ((K/(n_elements)) <= 1024) {                                           \
+      dim3 block(K/(n_elements));                                             \
+      dim3 grid(S);                                                           \
+      swish_##packed_type##_kernel<<<grid, block>>>( \
+        reinterpret_cast<element_type*>(x.data_ptr()),                        \
+        reinterpret_cast<element_type*>(y.data_ptr()), N);                    \
+    } else {                                                                  \
+      int N = 1;                                                              \
+      for (int i = 0; i < ndim; ++i) { N *= x.size(i); }                      \
+      dim3 block(256 / (n_elements));                                         \
+      dim3 grid((N + 256 - 1) / 256);                                         \
+      swish_##packed_type##_kernel<<<grid, block>>>( \
+        reinterpret_cast<element_type*>(x.data_ptr()),                        \
+        reinterpret_cast<element_type*>(y.data_ptr()), N);                    \
+    }                                                                         \
+  }                                                                           \
+}
+
+TORCH_BINDING_SWISH(f32, torch::kFloat32, float, 1)
+TORCH_BINDING_SWISH(f32x4, torch::kFloat32, float, 4)
+TORCH_BINDING_SWISH(f16, torch::kHalf, half, 1)
+TORCH_BINDING_SWISH(f16x2, torch::kHalf, half, 2)
+TORCH_BINDING_SWISH(f16x8, torch::kHalf, half, 8)
+TORCH_BINDING_SWISH(f16x8_pack, torch::kHalf, half, 8)
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+TORCH_BINDING_COMMON_EXTENSION(swish_f32)
+TORCH_BINDING_COMMON_EXTENSION(swish_f32x4)
+TORCH_BINDING_COMMON_EXTENSION(swish_f16)
+TORCH_BINDING_COMMON_EXTENSION(swish_f16x2)
+TORCH_BINDING_COMMON_EXTENSION(swish_f16x8)
+TORCH_BINDING_COMMON_EXTENSION(swish_f16x8_pack)
+}
\ No newline at end of file
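Note (editorial, not part of the patch): Swish with beta = 1 is exactly the function PyTorch ships as SiLU, which makes a convenient ground truth for the kernels above. A minimal reference sketch in plain PyTorch:

```python
import torch
import torch.nn.functional as F

def swish_ref(x: torch.Tensor) -> torch.Tensor:
    # Swish (beta = 1): y = x * sigmoid(x) = x / (1 + exp(-x))
    return x * torch.sigmoid(x)

x = torch.randn(16)
torch.testing.assert_close(swish_ref(x), F.silu(x))  # SiLU is the same function
```
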
diff --git a/swish/swish.py b/swish/swish.py
new file mode 100644
index 00000000..bc330235
--- /dev/null
+++ b/swish/swish.py
@@ -0,0 +1,85 @@
+import torch
+import time
+from torch.utils.cpp_extension import load
+from typing import Optional
+from functools import partial
+
+torch.set_grad_enabled(False)
+
+# Load the CUDA kernel as a python module
+lib = load(name='swish_lib',
+           sources=['swish.cu'],
+           extra_cuda_cflags=[
+               "-O3",
+               "-U__CUDA_NO_HALF_OPERATORS__",
+               "-U__CUDA_NO_HALF_CONVERSIONS__",
+               "-U__CUDA_NO_HALF2_OPERATORS__",
+               "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+               "--expt-relaxed-constexpr",
+               "--expt-extended-lambda",
+               "--use_fast_math",
+           ],
+           extra_cflags=['-std=c++17'])
+
+def run_benchmark(perf_func: callable, x: torch.Tensor, tag: str,
+                  out: Optional[torch.Tensor] = None, warmup: int = 10,
+                  iters: int = 1000, show_all: bool = False):
+    if out is not None:
+        out.fill_(0)
+    # warmup
+    if out is not None:
+        for i in range(warmup):
+            perf_func(x, out)
+    else:
+        for i in range(warmup):
+            out = perf_func(x)
+    torch.cuda.synchronize()
+    start = time.time()
+    # iters
+    if out is not None:
+        for i in range(iters):
+            perf_func(x, out)
+    else:
+        for i in range(iters):
+            out = perf_func(x)
+    torch.cuda.synchronize()
+    end = time.time()
+    total_time = (end - start) * 1000 # ms
+    mean_time = total_time / iters
+    out_info = f"out_{tag}"
+    out_val = out.flatten().detach().cpu().numpy().tolist()[:2]
+    out_val = [round(v, 8) for v in out_val]
+    out_val = [f"{v:<12}" for v in out_val]
+    print(f"{out_info:>18}: {out_val}, time:{mean_time:.8f}ms")
+    if show_all: print(out)
+    return out, mean_time
+
+def torch_swish(x, out=None):
+    if out is None:
+        return x * torch.sigmoid(x)
+    else:
+        torch.sigmoid(x, out=out)
+        out.mul_(x)
+        return out
+
+Ss = [1024, 2048, 4096]
+Ks = [1024, 2048, 4096]
+SKs = [(S, K) for S in Ss for K in Ks]
+
+for (S, K) in SKs:
+    print("-" * 85)
+    print(" " * 40 + f"S={S}, K={K}")
+    x = torch.randn((S, K)).cuda().float().contiguous()
+    y = torch.zeros_like(x).cuda().float().contiguous()
+    run_benchmark(lib.swish_f32,   x, "f32",    y)
+    run_benchmark(lib.swish_f32x4, x, "f32x4",  y)
+    run_benchmark(torch_swish,     x, "f32_th", y)
+    print("-" * 85)
+    x_f16 = x.half().contiguous()
+    y_f16 = y.half().contiguous()
+    run_benchmark(lib.swish_f16,        x_f16, "f16",       y_f16)
+    run_benchmark(lib.swish_f16x2,      x_f16, "f16x2",     y_f16)
+    run_benchmark(lib.swish_f16x8,      x_f16, "f16x8",     y_f16)
+    run_benchmark(lib.swish_f16x8_pack, x_f16, "f16x8pack", y_f16)
+    run_benchmark(torch_swish,          x_f16, "f16_th",    y_f16)
+    print("-" * 85)
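Note (editorial, not part of the patch): swish.py above only benchmarks and prints the first two outputs; a correctness pass against torch.nn.functional.silu is a natural companion. A sketch, assuming the same JIT build as swish.py (module name swish_lib, swish.cu in the working directory):

```python
import torch
import torch.nn.functional as F
from torch.utils.cpp_extension import load

lib = load(name='swish_lib', sources=['swish.cu'],
           extra_cuda_cflags=["-O3", "--use_fast_math",
                              "-U__CUDA_NO_HALF_OPERATORS__",
                              "-U__CUDA_NO_HALF_CONVERSIONS__"],
           extra_cflags=['-std=c++17'])

x = torch.randn(1024, 1024, device="cuda").contiguous()
y = torch.zeros_like(x)
lib.swish_f32x4(x, y)  # kernel writes x * sigmoid(x) into y
# loose tolerances: --use_fast_math swaps expf for the fast approximation
torch.testing.assert_close(y, F.silu(x), atol=1e-4, rtol=1e-4)

x16 = x.half().contiguous()
y16 = torch.zeros_like(x16)
lib.swish_f16x8_pack(x16, y16)
torch.testing.assert_close(y16, F.silu(x16), atol=1e-2, rtol=1e-2)
```
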
From 4ad3addf16827f0bd6a598e376be01c32b37ccaa Mon Sep 17 00:00:00 2001
From: wangzijian <2087291150@qq.com>
Date: Wed, 16 Oct 2024 08:52:44 +0000
Subject: [PATCH 2/5] [SWISH][FP16] add README.md.

---
 swish/README.md | 136 ++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 136 insertions(+)

diff --git a/swish/README.md b/swish/README.md
index e69de29b..1854ba40 100644
--- a/swish/README.md
+++ b/swish/README.md
@@ -0,0 +1,136 @@
+# Swish
+
+## 0x00 Overview
+
+This directory contains:
+
+- [X] swish_f32_kernel
+- [X] swish_f32x4_kernel (float4 vectorized version)
+- [X] swish_f16_kernel (fp16 version)
+- [X] swish_f16x2_kernel (fp16 vectorized version)
+- [X] swish_f16x8_kernel (fp16 vectorized version)
+- [X] swish_f16x8_pack_kernel (fp16 vectorized, packed version)
+- [X] PyTorch bindings
+
+
+## Test
+
+```bash
+# Build for the Ada architecture only; if unset, all architectures are compiled, which is slow: Volta, Ampere, Ada, Hopper, ...
+export TORCH_CUDA_ARCH_LIST=Ada
+python3 relu.py
+```
+
+Output:
+
+```bash
+-------------------------------------------------------------------------------------
+                                        S=1024, K=1024
+           out_f32: ['-0.27626503 ', '-0.04989763 '], time:0.00399876ms
+         out_f32x4: ['-0.27626503 ', '-0.04989763 '], time:0.00343800ms
+        out_f32_th: ['-0.27626505 ', '-0.04989763 '], time:0.00745869ms
+-------------------------------------------------------------------------------------
+           out_f16: ['-0.27636719 ', '-0.04989624 '], time:0.00438786ms
+         out_f16x2: ['-0.27636719 ', '-0.04989624 '], time:0.00286722ms
+         out_f16x8: ['-0.27636719 ', '-0.04989624 '], time:0.00279593ms
+     out_f16x8pack: ['-0.27636719 ', '-0.04989624 '], time:0.00246906ms
+        out_f16_th: ['-0.27636719 ', '-0.04989624 '], time:0.00530934ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=1024, K=2048
+           out_f32: ['-0.11616411 ', '0.92168581  '], time:0.00554204ms
+         out_f32x4: ['-0.11616411 ', '0.92168581  '], time:0.00527859ms
+        out_f32_th: ['-0.11616411 ', '0.92168581  '], time:0.01150441ms
+-------------------------------------------------------------------------------------
+           out_f16: ['-0.11608887 ', '0.92089844  '], time:0.00581026ms
+         out_f16x2: ['-0.11608887 ', '0.92089844  '], time:0.00510240ms
+         out_f16x8: ['-0.11608887 ', '0.92089844  '], time:0.00489950ms
+     out_f16x8pack: ['-0.11608887 ', '0.92089844  '], time:0.00348830ms
+        out_f16_th: ['-0.1161499  ', '0.921875    '], time:0.00732565ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=1024, K=4096
+           out_f32: ['1.75988805  ', '-0.07706322 '], time:0.00934935ms
+         out_f32x4: ['1.75988805  ', '-0.07706322 '], time:0.00900555ms
+        out_f32_th: ['1.75988805  ', '-0.07706322 '], time:0.01952362ms
+-------------------------------------------------------------------------------------
+           out_f16: ['1.75878906  ', '-0.07702637 '], time:0.00990725ms
+         out_f16x2: ['1.75878906  ', '-0.07702637 '], time:0.00903893ms
+         out_f16x8: ['1.75878906  ', '-0.07702637 '], time:0.00909162ms
+     out_f16x8pack: ['1.75878906  ', '-0.07702637 '], time:0.00556612ms
+        out_f16_th: ['1.75878906  ', '-0.07702637 '], time:0.01149321ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=2048, K=1024
+           out_f32: ['-0.2783826  ', '-0.11238551 '], time:0.00662041ms
+         out_f32x4: ['-0.2783826  ', '-0.11238551 '], time:0.00530052ms
+        out_f32_th: ['-0.2783826  ', '-0.11238552 '], time:0.01146960ms
+-------------------------------------------------------------------------------------
+           out_f16: ['-0.27856445 ', '-0.11242676 '], time:0.00733256ms
+         out_f16x2: ['-0.27856445 ', '-0.11242676 '], time:0.00414872ms
+         out_f16x8: ['-0.27856445 ', '-0.11242676 '], time:0.00447845ms
+     out_f16x8pack: ['-0.27856445 ', '-0.11242676 '], time:0.00336623ms
+        out_f16_th: ['-0.27832031 ', '-0.11236572 '], time:0.00891113ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=2048, K=2048
+           out_f32: ['-0.2677404  ', '0.06417678  '], time:0.01659989ms
+         out_f32x4: ['-0.2677404  ', '0.06417678  '], time:0.02058291ms
+        out_f32_th: ['-0.2677404  ', '0.06417678  '], time:0.02069759ms
+-------------------------------------------------------------------------------------
+           out_f16: ['-0.26757812 ', '0.06414795  '], time:0.00988817ms
+         out_f16x2: ['-0.26757812 ', '0.06414795  '], time:0.00977159ms
+         out_f16x8: ['-0.26757812 ', '0.06414795  '], time:0.00958800ms
+     out_f16x8pack: ['-0.26757812 ', '0.06414795  '], time:0.00598550ms
+        out_f16_th: ['-0.26782227 ', '0.06420898  '], time:0.01161885ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=2048, K=4096
+           out_f32: ['0.87521237  ', '-0.23885883 '], time:0.01902914ms
+         out_f32x4: ['0.87521237  ', '-0.23885883 '], time:0.01689911ms
+        out_f32_th: ['0.87521237  ', '-0.23885885 '], time:0.03624034ms
+-------------------------------------------------------------------------------------
+           out_f16: ['0.875       ', '-0.23876953 '], time:0.01800466ms
+         out_f16x2: ['0.875       ', '-0.23876953 '], time:0.01659966ms
+         out_f16x8: ['0.875       ', '-0.23876953 '], time:0.01692343ms
+     out_f16x8pack: ['0.875       ', '-0.23876953 '], time:0.00913978ms
+        out_f16_th: ['0.875       ', '-0.23876953 '], time:0.01969314ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=4096, K=1024
+           out_f32: ['1.20880175  ', '0.23003602  '], time:0.01281476ms
+         out_f32x4: ['1.20880175  ', '0.23003602  '], time:0.00889063ms
+        out_f32_th: ['1.20880175  ', '0.23003602  '], time:0.02118397ms
+-------------------------------------------------------------------------------------
+           out_f16: ['1.20996094  ', '0.22998047  '], time:0.01336646ms
+         out_f16x2: ['1.20996094  ', '0.22998047  '], time:0.00935316ms
+         out_f16x8: ['1.20996094  ', '0.22998047  '], time:0.00812125ms
+     out_f16x8pack: ['1.20996094  ', '0.22998047  '], time:0.00540519ms
+        out_f16_th: ['1.20898438  ', '0.22998047  '], time:0.01149321ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=4096, K=2048
+           out_f32: ['-0.20028742 ', '0.00593234  '], time:0.03857923ms
+         out_f32x4: ['-0.20028742 ', '0.00593234  '], time:0.01753092ms
+        out_f32_th: ['-0.20028742 ', '0.00593234  '], time:0.03842711ms
+-------------------------------------------------------------------------------------
+           out_f16: ['-0.20019531 ', '0.00593185  '], time:0.01945615ms
+         out_f16x2: ['-0.20019531 ', '0.00593185  '], time:0.03284836ms
+         out_f16x8: ['-0.20019531 ', '0.00593185  '], time:0.01644540ms
+     out_f16x8pack: ['-0.20019531 ', '0.00593185  '], time:0.01503658ms
+        out_f16_th: ['-0.20019531 ', '0.00593185  '], time:0.03441119ms
+-------------------------------------------------------------------------------------
+-------------------------------------------------------------------------------------
+                                        S=4096, K=4096
+           out_f32: ['0.68447012  ', '0.60815662  '], time:0.25561571ms
+         out_f32x4: ['0.68447012  ', '0.60815662  '], time:0.27445269ms
+        out_f32_th: ['0.68447006  ', '0.60815662  '], time:0.65941596ms
+-------------------------------------------------------------------------------------
+           out_f16: ['0.68457031  ', '0.60791016  '], time:0.06837082ms
+         out_f16x2: ['0.68457031  ', '0.60791016  '], time:0.06547904ms
+         out_f16x8: ['0.68457031  ', '0.60791016  '], time:0.04474044ms
+     out_f16x8pack: ['0.68457031  ', '0.60791016  '], time:0.03068280ms
+        out_f16_th: ['0.68457031  ', '0.60791016  '], time:0.03648376ms
+-------------------------------------------------------------------------------------
+
+```

From 2fd14758e3e91e9ec4f3d9419410ed687ba7417d Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Thu, 17 Oct 2024 09:18:55 +0800
Subject: [PATCH 3/5] Update swish.cu

---
 swish/swish.cu | 113 +++++++++++++++++++++++++------------------------
 1 file changed, 57 insertions(+), 56 deletions(-)

diff --git a/swish/swish.cu b/swish/swish.cu
index eebaa9b3..0438dcc8 100644
--- a/swish/swish.cu
+++ b/swish/swish.cu
@@ -18,79 +18,80 @@
 // -------------------------------------- FP32 --------------------------------------
 // Swish x: N, y: N y=x*sigmoid(x)
 __device__ __forceinline__ float swish(float x) {
-    return x / (1.0f + expf(-x));
+  return x / (1.0f + expf(-x));
 }
 
 __global__ void swish_f32_kernel(float* x, float* y, int N) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < N) y[idx] = swish(x[idx]);
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < N) y[idx] = swish(x[idx]);
 }
 
 __global__ void swish_f32x4_kernel(float* x, float* y, int N) {
-    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
-    if (idx < N) {
-        float4 reg_x = FLOAT4(x[idx]);
-        float4 reg_y;
-        reg_y.x = swish(reg_x.x);
-        reg_y.y = swish(reg_x.y);
-        reg_y.z = swish(reg_x.z);
-        reg_y.w = swish(reg_x.w);
-        FLOAT4(y[idx]) = reg_y;
-    }
+  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
+  if (idx < N) {
+    float4 reg_x = FLOAT4(x[idx]);
+    float4 reg_y;
+    reg_y.x = swish(reg_x.x);
+    reg_y.y = swish(reg_x.y);
+    reg_y.z = swish(reg_x.z);
+    reg_y.w = swish(reg_x.w);
+    FLOAT4(y[idx]) = reg_y;
+  }
 }
 
 // -------------------------------------- FP16 --------------------------------------
 __device__ __forceinline__ half swish_half(half x) {
-    return __hmul(x, __hdiv(__float2half(1.0f), __hadd(__float2half(1.0f), hexp(__hneg(x)))));
+  return __hmul(x, __hdiv(
+    __float2half(1.0f), __hadd(__float2half(1.0f), hexp(__hneg(x)))));
 }
 
 __global__ void swish_f16_kernel(half* x, half* y, int N) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    if (idx < N) y[idx] = swish_half(x[idx]);
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < N) y[idx] = swish_half(x[idx]);
 }
 
 __global__ void swish_f16x2_kernel(half* x, half* y, int N) {
-    int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
-    if (idx < N) {
-        half2 reg_x = HALF2(x[idx]);
-        half2 reg_y;
-        reg_y.x = swish_half(reg_x.x);
-        reg_y.y = swish_half(reg_x.y);
-        HALF2(y[idx]) = reg_y;
-    }
+  int idx = 2 * (blockIdx.x * blockDim.x + threadIdx.x);
+  if (idx < N) {
+    half2 reg_x = HALF2(x[idx]);
+    half2 reg_y;
+    reg_y.x = swish_half(reg_x.x);
+    reg_y.y = swish_half(reg_x.y);
+    HALF2(y[idx]) = reg_y;
+  }
 }
 
 __global__ void swish_f16x8_kernel(half* x, half* y, int N) {
-    int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
-    half2 reg_x_0 = HALF2(x[idx + 0]);
-    half2 reg_x_1 = HALF2(x[idx + 2]);
-    half2 reg_x_2 = HALF2(x[idx + 4]);
-    half2 reg_x_3 = HALF2(x[idx + 6]);
-    half2 reg_y_0, reg_y_1, reg_y_2, reg_y_3;
-    reg_y_0.x = swish_half(reg_x_0.x);
-    reg_y_0.y = swish_half(reg_x_0.y);
-    reg_y_1.x = swish_half(reg_x_1.x);
-    reg_y_1.y = swish_half(reg_x_1.y);
-    reg_y_2.x = swish_half(reg_x_2.x);
-    reg_y_2.y = swish_half(reg_x_2.y);
-    reg_y_3.x = swish_half(reg_x_3.x);
-    reg_y_3.y = swish_half(reg_x_3.y);
-    if ((idx + 0) < N) { HALF2(y[idx + 0]) = reg_y_0; }
-    if ((idx + 2) < N) { HALF2(y[idx + 2]) = reg_y_1; }
-    if ((idx + 4) < N) { HALF2(y[idx + 4]) = reg_y_2; }
-    if ((idx + 6) < N) { HALF2(y[idx + 6]) = reg_y_3; }
+  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
+  half2 reg_x_0 = HALF2(x[idx + 0]);
+  half2 reg_x_1 = HALF2(x[idx + 2]);
+  half2 reg_x_2 = HALF2(x[idx + 4]);
+  half2 reg_x_3 = HALF2(x[idx + 6]);
+  half2 reg_y_0, reg_y_1, reg_y_2, reg_y_3;
+  reg_y_0.x = swish_half(reg_x_0.x);
+  reg_y_0.y = swish_half(reg_x_0.y);
+  reg_y_1.x = swish_half(reg_x_1.x);
+  reg_y_1.y = swish_half(reg_x_1.y);
+  reg_y_2.x = swish_half(reg_x_2.x);
+  reg_y_2.y = swish_half(reg_x_2.y);
+  reg_y_3.x = swish_half(reg_x_3.x);
+  reg_y_3.y = swish_half(reg_x_3.y);
+  if ((idx + 0) < N) { HALF2(y[idx + 0]) = reg_y_0; }
+  if ((idx + 2) < N) { HALF2(y[idx + 2]) = reg_y_1; }
+  if ((idx + 4) < N) { HALF2(y[idx + 4]) = reg_y_2; }
+  if ((idx + 6) < N) { HALF2(y[idx + 6]) = reg_y_3; }
 }
 
 __global__ void swish_f16x8_pack_kernel(half* x, half* y, int N) {
-    int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
-    half pack_x[8], pack_y[8];
-    LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]);
+  int idx = 8 * (blockIdx.x * blockDim.x + threadIdx.x);
+  half pack_x[8], pack_y[8];
+  LDST128BITS(pack_x[0]) = LDST128BITS(x[idx]);
 
-#pragma unroll
-    for (int i = 0; i < 8; i++) {
-        pack_y[i] = swish_half(pack_x[i]);
-    }
-    if ((idx + 7) < N) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
+  #pragma unroll
+  for (int i = 0; i < 8; i++) {
+    pack_y[i] = swish_half(pack_x[i]);
+  }
+  if ((idx + 7) < N) { LDST128BITS(y[idx]) = LDST128BITS(pack_y[0]); }
 }
 
 // --------------------- PyTorch bindings for custom kernel -----------------------
@@ -104,8 +105,8 @@ if(((T).options().dtype() != (th_type))) { \
   throw std::runtime_error("values must be "#th_type);       \
 }
 
-#define TORCH_BINDING_SWISH(packed_type, th_type, element_type, n_elements) \
-void swish_##packed_type(torch::Tensor x, torch::Tensor y) { \
+#define TORCH_BINDING_SWISH(packed_type, th_type, element_type, n_elements)   \
+void swish_##packed_type(torch::Tensor x, torch::Tensor y) {                  \
   CHECK_TORCH_TENSOR_DTYPE(x, (th_type))                                      \
   CHECK_TORCH_TENSOR_DTYPE(y, (th_type))                                      \
   const int ndim = x.dim();                                                   \
@@ -114,7 +115,7 @@ void swish_##packed_type(torch::Tensor x, torch::Tensor y) {
     for (int i = 0; i < ndim; ++i) { N *= x.size(i); }                        \
     dim3 block(256 / (n_elements));                                           \
     dim3 grid((N + 256 - 1) / 256);                                           \
-    swish_##packed_type##_kernel<<<grid, block>>>( \
+    swish_##packed_type##_kernel<<<grid, block>>>(                            \
      reinterpret_cast<element_type*>(x.data_ptr()),                           \
      reinterpret_cast<element_type*>(y.data_ptr()), N);                       \
   } else {                                                                    \
@@ -124,7 +125,7 @@ void swish_##packed_type(torch::Tensor x, torch::Tensor y) {
     if ((K/(n_elements)) <= 1024) {                                           \
       dim3 block(K/(n_elements));                                             \
       dim3 grid(S);                                                           \
-      swish_##packed_type##_kernel<<<grid, block>>>( \
+      swish_##packed_type##_kernel<<<grid, block>>>(                          \
        reinterpret_cast<element_type*>(x.data_ptr()),                         \
        reinterpret_cast<element_type*>(y.data_ptr()), N);                     \
     } else {                                                                  \
@@ -132,7 +133,7 @@ void swish_##packed_type(torch::Tensor x, torch::Tensor y) {
       for (int i = 0; i < ndim; ++i) { N *= x.size(i); }                      \
       dim3 block(256 / (n_elements));                                         \
       dim3 grid((N + 256 - 1) / 256);                                         \
-      swish_##packed_type##_kernel<<<grid, block>>>( \
+      swish_##packed_type##_kernel<<<grid, block>>>(                          \
        reinterpret_cast<element_type*>(x.data_ptr()),                         \
        reinterpret_cast<element_type*>(y.data_ptr()), N);                     \
     }                                                                         \
@@ -153,4 +154,4 @@ TORCH_BINDING_COMMON_EXTENSION(swish_f16)
 TORCH_BINDING_COMMON_EXTENSION(swish_f16x2)
 TORCH_BINDING_COMMON_EXTENSION(swish_f16x8)
 TORCH_BINDING_COMMON_EXTENSION(swish_f16x8_pack)
-}
\ No newline at end of file
+}
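Note (editorial, not part of the patch): the TORCH_BINDING_SWISH macro cleaned up above dispatches between two launch shapes. A plain-Python mirror of that logic, with illustrative names only, to make the macro easier to audit:

```python
def launch_config(shape, n_elements):
    # n_elements = values consumed per thread (1, 2, 4 or 8 above).
    if len(shape) == 2 and shape[1] // n_elements <= 1024:
        S, K = shape
        return S, K // n_elements           # grid = S rows, one block per row
    N = 1
    for s in shape:
        N *= s                              # flatten to 1-D
    # 256 elements per block regardless of vector width
    return (N + 256 - 1) // 256, 256 // n_elements

assert launch_config((1024, 2048), 8) == (1024, 256)  # per-row path
assert launch_config((8, 16, 32), 4) == (16, 64)      # flattened path
```
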
From 23c1352d6baf4feca8511e1ba2d26c1b9825d0d4 Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Thu, 17 Oct 2024 09:19:17 +0800
Subject: [PATCH 4/5] Update README.md

---
 swish/README.md | 148 ++++++++++++++++++++++++------------------------
 1 file changed, 74 insertions(+), 74 deletions(-)

diff --git a/swish/README.md b/swish/README.md
index 1854ba40..0cdc0362 100644
--- a/swish/README.md
+++ b/swish/README.md
@@ -18,7 +18,7 @@
 ```bash
 # Build for the Ada architecture only; if unset, all architectures are compiled, which is slow: Volta, Ampere, Ada, Hopper, ...
 export TORCH_CUDA_ARCH_LIST=Ada
-python3 relu.py
+python3 swish.py
 ```
 
 Output:
@@ -26,111 +26,111 @@ python3 relu.py
 ```bash
 -------------------------------------------------------------------------------------
                                         S=1024, K=1024
-           out_f32: ['-0.27626503 ', '-0.04989763 '], time:0.00399876ms
-         out_f32x4: ['-0.27626503 ', '-0.04989763 '], time:0.00343800ms
-        out_f32_th: ['-0.27626505 ', '-0.04989763 '], time:0.00745869ms
+           out_f32: ['0.46177661  ', '-0.10888041 '], time:0.01246500ms
+         out_f32x4: ['0.46177661  ', '-0.10888041 '], time:0.01006508ms
+        out_f32_th: ['0.46177667  ', '-0.10888041 '], time:0.03012419ms
 -------------------------------------------------------------------------------------
-           out_f16: ['-0.27636719 ', '-0.04989624 '], time:0.00438786ms
-         out_f16x2: ['-0.27636719 ', '-0.04989624 '], time:0.00286722ms
-         out_f16x8: ['-0.27636719 ', '-0.04989624 '], time:0.00279593ms
-     out_f16x8pack: ['-0.27636719 ', '-0.04989624 '], time:0.00246906ms
-        out_f16_th: ['-0.27636719 ', '-0.04989624 '], time:0.00530934ms
+           out_f16: ['0.46191406  ', '-0.10894775 '], time:0.01299334ms
+         out_f16x2: ['0.46191406  ', '-0.10894775 '], time:0.01036119ms
+         out_f16x8: ['0.46191406  ', '-0.10894775 '], time:0.00979590ms
+     out_f16x8pack: ['0.46191406  ', '-0.10894775 '], time:0.00972557ms
+        out_f16_th: ['0.46191406  ', '-0.10888672 '], time:0.02423882ms
 -------------------------------------------------------------------------------------
 -------------------------------------------------------------------------------------
                                         S=1024, K=2048
-           out_f32: ['-0.11616411 ', '0.92168581  '], time:0.00554204ms
-         out_f32x4: ['-0.11616411 ', '0.92168581  '], time:0.00527859ms
-        out_f32_th: ['-0.11616411 ', '0.92168581  '], time:0.01150441ms
+           out_f32: ['-0.27797085 ', '0.71514565  '], time:0.01415992ms
+         out_f32x4: ['-0.27797085 ', '0.71514565  '], time:0.01159716ms
+        out_f32_th: ['-0.27797085 ', '0.71514559  '], time:0.02964258ms
 -------------------------------------------------------------------------------------
-           out_f16: ['-0.11608887 ', '0.92089844  '], time:0.00581026ms
-         out_f16x2: ['-0.11608887 ', '0.92089844  '], time:0.00510240ms
-         out_f16x8: ['-0.11608887 ', '0.92089844  '], time:0.00489950ms
-     out_f16x8pack: ['-0.11608887 ', '0.92089844  '], time:0.00348830ms
-        out_f16_th: ['-0.1161499  ', '0.921875    '], time:0.00732565ms
+           out_f16: ['-0.27807617 ', '0.71582031  '], time:0.01473880ms
+         out_f16x2: ['-0.27807617 ', '0.71582031  '], time:0.01404881ms
+         out_f16x8: ['-0.27807617 ', '0.71582031  '], time:0.01127148ms
+     out_f16x8pack: ['-0.27807617 ', '0.71582031  '], time:0.01101518ms
+        out_f16_th: ['-0.27807617 ', '0.71533203  '], time:0.02657008ms
 -------------------------------------------------------------------------------------
 -------------------------------------------------------------------------------------
                                         S=1024, K=4096
-           out_f32: ['1.75988805  ', '-0.07706322 '], time:0.00934935ms
-         out_f32x4: ['1.75988805  ', '-0.07706322 '], time:0.00900555ms
-        out_f32_th: ['1.75988805  ', '-0.07706322 '], time:0.01952362ms
+           out_f32: ['0.29988611  ', '-0.2541697  '], time:0.01959276ms
+         out_f32x4: ['0.29988611  ', '-0.2541697  '], time:0.01605868ms
+        out_f32_th: ['0.29988611  ', '-0.25416973 '], time:0.03745818ms
 -------------------------------------------------------------------------------------
-           out_f16: ['1.75878906  ', '-0.07702637 '], time:0.00990725ms
-         out_f16x2: ['1.75878906  ', '-0.07702637 '], time:0.00903893ms
-         out_f16x8: ['1.75878906  ', '-0.07702637 '], time:0.00909162ms
-     out_f16x8pack: ['1.75878906  ', '-0.07702637 '], time:0.00556612ms
-        out_f16_th: ['1.75878906  ', '-0.07702637 '], time:0.01149321ms
+           out_f16: ['0.30004883  ', '-0.25415039 '], time:0.02078271ms
+         out_f16x2: ['0.30004883  ', '-0.25415039 '], time:0.01729155ms
+         out_f16x8: ['0.30004883  ', '-0.25415039 '], time:0.01489425ms
+     out_f16x8pack: ['0.30004883  ', '-0.25415039 '], time:0.01351643ms
+        out_f16_th: ['0.29980469  ', '-0.25415039 '], time:0.03149080ms
 -------------------------------------------------------------------------------------
 -------------------------------------------------------------------------------------
                                         S=2048, K=1024
-           out_f32: ['-0.2783826  ', '-0.11238551 '], time:0.00662041ms
-         out_f32x4: ['-0.2783826  ', '-0.11238551 '], time:0.00530052ms
-        out_f32_th: ['-0.2783826  ', '-0.11238552 '], time:0.01146960ms
+           out_f32: ['-0.07777861 ', '-0.27842814 '], time:0.01640201ms
+         out_f32x4: ['-0.07777861 ', '-0.27842814 '], time:0.01180029ms
+        out_f32_th: ['-0.07777861 ', '-0.27842814 '], time:0.02952218ms
 -------------------------------------------------------------------------------------
-           out_f16: ['-0.27856445 ', '-0.11242676 '], time:0.00733256ms
-         out_f16x2: ['-0.27856445 ', '-0.11242676 '], time:0.00414872ms
-         out_f16x8: ['-0.27856445 ', '-0.11242676 '], time:0.00447845ms
-     out_f16x8pack: ['-0.27856445 ', '-0.11242676 '], time:0.00336623ms
-        out_f16_th: ['-0.27832031 ', '-0.11236572 '], time:0.00891113ms
+           out_f16: ['-0.07775879 ', '-0.27856445 '], time:0.01758027ms
+         out_f16x2: ['-0.07775879 ', '-0.27856445 '], time:0.01236153ms
+         out_f16x8: ['-0.07775879 ', '-0.27856445 '], time:0.01109338ms
+     out_f16x8pack: ['-0.07775879 ', '-0.27856445 '], time:0.01091790ms
+        out_f16_th: ['-0.07775879 ', '-0.27856445 '], time:0.02657914ms
 -------------------------------------------------------------------------------------
 -------------------------------------------------------------------------------------
                                         S=2048, K=2048
-           out_f32: ['-0.2677404  ', '0.06417678  '], time:0.01659989ms
-         out_f32x4: ['-0.2677404  ', '0.06417678  '], time:0.02058291ms
-        out_f32_th: ['-0.2677404  ', '0.06417678  '], time:0.02069759ms
+           out_f32: ['-0.14754841 ', '-0.21989606 '], time:0.01957679ms
+         out_f32x4: ['-0.14754841 ', '-0.21989606 '], time:0.01496792ms
+        out_f32_th: ['-0.14754841 ', '-0.21989603 '], time:0.03751612ms
 -------------------------------------------------------------------------------------
-           out_f16: ['-0.26757812 ', '0.06414795  '], time:0.00988817ms
-         out_f16x2: ['-0.26757812 ', '0.06414795  '], time:0.00977159ms
-         out_f16x8: ['-0.26757812 ', '0.06414795  '], time:0.00958800ms
-     out_f16x8pack: ['-0.26757812 ', '0.06414795  '], time:0.00598550ms
-        out_f16_th: ['-0.26782227 ', '0.06420898  '], time:0.01161885ms
+           out_f16: ['-0.14758301 ', '-0.21984863 '], time:0.02085924ms
+         out_f16x2: ['-0.14758301 ', '-0.21984863 '], time:0.01961517ms
+         out_f16x8: ['-0.14758301 ', '-0.21984863 '], time:0.01386237ms
+     out_f16x8pack: ['-0.14758301 ', '-0.21984863 '], time:0.01334929ms
+        out_f16_th: ['-0.14758301 ', '-0.21984863 '], time:0.03151488ms
 -------------------------------------------------------------------------------------
 -------------------------------------------------------------------------------------
                                         S=2048, K=4096
-           out_f32: ['0.87521237  ', '-0.23885883 '], time:0.01902914ms
-         out_f32x4: ['0.87521237  ', '-0.23885883 '], time:0.01689911ms
-        out_f32_th: ['0.87521237  ', '-0.23885885 '], time:0.03624034ms
+           out_f32: ['1.07876182  ', '-0.27844051 '], time:0.03036070ms
+         out_f32x4: ['1.07876182  ', '-0.27844051 '], time:0.02339220ms
+        out_f32_th: ['1.07876182  ', '-0.27844048 '], time:0.05310464ms
 -------------------------------------------------------------------------------------
-           out_f16: ['0.875       ', '-0.23876953 '], time:0.01800466ms
-         out_f16x2: ['0.875       ', '-0.23876953 '], time:0.01659966ms
-         out_f16x8: ['0.875       ', '-0.23876953 '], time:0.01692343ms
-     out_f16x8pack: ['0.875       ', '-0.23876953 '], time:0.00913978ms
-        out_f16_th: ['0.875       ', '-0.23876953 '], time:0.01969314ms
+           out_f16: ['1.078125    ', '-0.27832031 '], time:0.03291988ms
+         out_f16x2: ['1.078125    ', '-0.27832031 '], time:0.02590466ms
+         out_f16x8: ['1.078125    ', '-0.27832031 '], time:0.02027988ms
+     out_f16x8pack: ['1.078125    ', '-0.27832031 '], time:0.01811814ms
+        out_f16_th: ['1.07910156  ', '-0.27832031 '], time:0.04083204ms
 -------------------------------------------------------------------------------------
 -------------------------------------------------------------------------------------
                                         S=4096, K=1024
-           out_f32: ['1.20880175  ', '0.23003602  '], time:0.01281476ms
-         out_f32x4: ['1.20880175  ', '0.23003602  '], time:0.00889063ms
-        out_f32_th: ['1.20880175  ', '0.23003602  '], time:0.02118397ms
+           out_f32: ['0.31169948  ', '-0.18232882 '], time:0.02427077ms
+         out_f32x4: ['0.31169948  ', '-0.18232882 '], time:0.01515222ms
+        out_f32_th: ['0.31169948  ', '-0.18232881 '], time:0.03754425ms
 -------------------------------------------------------------------------------------
-           out_f16: ['1.20996094  ', '0.22998047  '], time:0.01336646ms
-         out_f16x2: ['1.20996094  ', '0.22998047  '], time:0.00935316ms
-         out_f16x8: ['1.20996094  ', '0.22998047  '], time:0.00812125ms
-     out_f16x8pack: ['1.20996094  ', '0.22998047  '], time:0.00540519ms
-        out_f16_th: ['1.20898438  ', '0.22998047  '], time:0.01149321ms
+           out_f16: ['0.31152344  ', '-0.18237305 '], time:0.02679300ms
+         out_f16x2: ['0.31152344  ', '-0.18237305 '], time:0.01617312ms
+         out_f16x8: ['0.31152344  ', '-0.18237305 '], time:0.01357770ms
+     out_f16x8pack: ['0.31152344  ', '-0.18237305 '], time:0.01324248ms
+        out_f16_th: ['0.31152344  ', '-0.18225098 '], time:0.03149295ms
 -------------------------------------------------------------------------------------
 -------------------------------------------------------------------------------------
                                         S=4096, K=2048
-           out_f32: ['-0.20028742 ', '0.00593234  '], time:0.03857923ms
-         out_f32x4: ['-0.20028742 ', '0.00593234  '], time:0.01753092ms
-        out_f32_th: ['-0.20028742 ', '0.00593234  '], time:0.03842711ms
+           out_f32: ['1.5033319   ', '0.17473438  '], time:0.03030729ms
+         out_f32x4: ['1.5033319   ', '0.17473438  '], time:0.02150083ms
+        out_f32_th: ['1.5033319   ', '0.17473438  '], time:0.05257607ms
 -------------------------------------------------------------------------------------
-           out_f16: ['-0.20019531 ', '0.00593185  '], time:0.01945615ms
-         out_f16x2: ['-0.20019531 ', '0.00593185  '], time:0.03284836ms
-         out_f16x8: ['-0.20019531 ', '0.00593185  '], time:0.01644540ms
-     out_f16x8pack: ['-0.20019531 ', '0.00593185  '], time:0.01503658ms
-        out_f16_th: ['-0.20019531 ', '0.00593185  '], time:0.03441119ms
+           out_f16: ['1.50390625  ', '0.17468262  '], time:0.03289509ms
+         out_f16x2: ['1.50390625  ', '0.17468262  '], time:0.03073120ms
+         out_f16x8: ['1.50390625  ', '0.17468262  '], time:0.01862860ms
+     out_f16x8pack: ['1.50390625  ', '0.17468262  '], time:0.01772857ms
+        out_f16_th: ['1.50390625  ', '0.17468262  '], time:0.04082441ms
 -------------------------------------------------------------------------------------
 -------------------------------------------------------------------------------------
                                         S=4096, K=4096
-           out_f32: ['0.68447012  ', '0.60815662  '], time:0.25561571ms
-         out_f32x4: ['0.68447012  ', '0.60815662  '], time:0.27445269ms
-        out_f32_th: ['0.68447006  ', '0.60815662  '], time:0.65941596ms
--------------------------------------------------------------------------------------
-           out_f16: ['0.68457031  ', '0.60791016  '], time:0.06837082ms
-         out_f16x2: ['0.68457031  ', '0.60791016  '], time:0.06547904ms
-         out_f16x8: ['0.68457031  ', '0.60791016  '], time:0.04474044ms
-     out_f16x8pack: ['0.68457031  ', '0.60791016  '], time:0.03068280ms
-        out_f16_th: ['0.68457031  ', '0.60791016  '], time:0.03648376ms
+           out_f32: ['-0.05288643 ', '-0.14218464 '], time:0.19254756ms
+         out_f32x4: ['-0.05288643 ', '-0.14218464 '], time:0.19258785ms
+        out_f32_th: ['-0.05288643 ', '-0.14218464 '], time:0.48660636ms
+-------------------------------------------------------------------------------------
+           out_f16: ['-0.052948   ', '-0.14221191 '], time:0.05689216ms
+         out_f16x2: ['-0.052948   ', '-0.14221191 '], time:0.04335928ms
+         out_f16x8: ['-0.052948   ', '-0.14221191 '], time:0.03096652ms
+     out_f16x8pack: ['-0.052948   ', '-0.14221191 '], time:0.02706647ms
+        out_f16_th: ['-0.05288696 ', '-0.14221191 '], time:0.05971408ms
 -------------------------------------------------------------------------------------
 
 ```

From bf318440c875f37b79c7a812c610125853df5cfb Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Thu, 17 Oct 2024 09:22:18 +0800
Subject: [PATCH 5/5] Update README.md

---
 README.md | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/README.md b/README.md
index de6ed662..52496262 100644
--- a/README.md
+++ b/README.md
@@ -48,6 +48,12 @@
 | ✔️ [gelu_f16x2](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️|
 | ✔️ [gelu_f16x8](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️|
 | ✔️ [gelu_f16x8_pack](./gelu/gelu.cu)|f16|/|[link](./gelu/)|⭐️⭐️|
+| ✔️ [swish_f32](./swish/swish.cu)|f32|/|[link](./swish/)|⭐️|
+| ✔️ [swish_f32x4](./swish/swish.cu)|f32|/|[link](./swish/)|⭐️|
+| ✔️ [swish_f16](./swish/swish.cu)|f16|/|[link](./swish/)|⭐️|
+| ✔️ [swish_f16x2](./swish/swish.cu)|f16|/|[link](./swish/)|⭐️|
+| ✔️ [swish_f16x8](./swish/swish.cu)|f16|/|[link](./swish/)|⭐️|
+| ✔️ [swish_f16x8_pack](./swish/swish.cu)|f16|/|[link](./swish/)|⭐️⭐️|
 | ✔️ [embedding_f32](./embedding/embedding.cu)|f32|/|[link](./embedding/)|⭐️|
 | ✔️ [embedding_f32x4](./embedding/embedding.cu)|f32|/|[link](./embedding/)|⭐️|
 | ✔️ [embedding_f32x4_pack](./embedding/embedding.cu)|f32|/|[link](./embedding/)|⭐️|
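
Note (editorial, not part of the patch series): for readers coming to this series from the root README table above, a minimal end-to-end usage sketch; the `swish_lib` name and file layout follow swish/swish.py, so adjust paths to your checkout:

```python
import torch
from torch.utils.cpp_extension import load

lib = load(name="swish_lib", sources=["swish/swish.cu"],
           extra_cuda_cflags=["-O3", "--use_fast_math"])

x = torch.randn(4096, 4096, dtype=torch.half, device="cuda").contiguous()
y = torch.empty_like(x)
lib.swish_f16x8_pack(x, y)   # y = x * sigmoid(x), 8 halfs per thread
print(y.flatten()[:2])
```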