
Commit 1dce94c

slaren authored and ggerganov committed
ggml : mul_mat_id use the same tensor for all the experts (llama/6387)
* ggml : update mul_mat_id to use the same tensor for all the experts
* update cuda
* minor
* update metal
* update test-backend-ops
* fix cuda
* Update ggml-metal.m
  Co-authored-by: Georgi Gerganov <[email protected]>
* update convert.py
* update convert-hf-to-gguf.py
* update convert.py for mixtral hf models
* Update convert-hf-to-gguf.py
  Co-authored-by: Georgi Gerganov <[email protected]>
* cuda : support non-pow-2 number of experts
* allow quantize to work for split and merged experts models in the same way
* cleanup + disable mmap automatically with split tensors models
* update imatrix
* test-backend-ops : test qwen argsort
* update grok model loading
* llama : add merged experts tensors to the grok tensor map
* minor
* gguf : bump version
* fix quantizing of merged experts
* convert-hf-to-gguf.py : update grok (untested)
* make linter happy
* cuda/argsort : use shared memory instead of pool memory
* convert : fix grok tensor names
* metal : add support for non-pow-2 argsort
* llama : more loader cleanup, better error checking
* cuda : fix warning
* llama : still use mmap for loading old models, but copy the data to a host buffer
* add review note
* llama : remove ffn tensor counting + add sanity check ggml-ci
* convert : fix handling of n_experts == None ggml-ci
* imatrix : fix ncall counters
* llama : produce error if imatrix size does not match
* quantize : terminate on errors + trace logs ggml-ci
* metal : pad shared memory to 16 bytes

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent f12e982 commit 1dce94c
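In short, the per-expert weight tensors that were previously passed as extra sources to mul_mat_id are replaced by a single 3D tensor whose third dimension indexes the experts; an expert is selected by offsetting into that dimension (`row_id*src0->nb[2]` in the CUDA diff below). A minimal host-side sketch of that addressing scheme, illustrative only and not part of the commit (`tensor3d` and `expert_view` are hypothetical stand-ins for `ggml_tensor` and the backend's pointer arithmetic):

```cpp
#include <cassert>
#include <cstdint>
#include <cstddef>
#include <vector>

// Simplified stand-in for a row-major ggml-style tensor:
// ne = sizes per dimension, nb = strides in bytes.
struct tensor3d {
    int64_t ne[3];   // ne[0] = columns, ne[1] = rows, ne[2] = experts
    size_t  nb[3];   // nb[2] is the byte stride between experts
    float * data;
};

// Select expert `row_id` from the stacked tensor: the same pointer arithmetic
// as `src0_row.data = src0_original + row_id*src0->nb[2]` in the diff.
static float * expert_view(const tensor3d & experts, int32_t row_id) {
    assert(row_id >= 0 && row_id < experts.ne[2]);
    return (float *)((char *) experts.data + row_id * experts.nb[2]);
}

int main() {
    const int64_t n_cols = 4, n_rows = 3, n_expert = 8; // n_expert need not be a power of 2
    std::vector<float> buf(n_cols * n_rows * n_expert, 0.0f);

    tensor3d experts = {
        { n_cols, n_rows, n_expert },
        { sizeof(float), n_cols * sizeof(float), n_cols * n_rows * sizeof(float) },
        buf.data(),
    };

    // The ids tensor produced by the router picks which expert each token uses.
    const int32_t ids[] = { 5, 0, 2 };
    for (int32_t id : ids) {
        float * w = expert_view(experts, id); // 2D slice [n_rows x n_cols] for this expert
        w[0] = 1.0f;                          // touch the first weight of the slice
    }
    return 0;
}
```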


6 files changed: +302, -643 lines changed


ggml-cuda.cu

Lines changed: 17 additions & 197 deletions
@@ -401,10 +401,8 @@ GGML_CALL static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t
 GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;

-    if (tensor->view_src != NULL && tensor->view_offs == 0) {
+    if (tensor->view_src != NULL) {
         assert(tensor->view_src->buffer->buft == buffer->buft);
-        tensor->backend = tensor->view_src->backend;
-        tensor->extra = tensor->view_src->extra;
         return;
     }

@@ -1962,227 +1960,49 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     }
 }

-#if 0
-template<typename ... Srcs>
-static __global__ void k_compute_batched_ptrs_id(
-        const void ** ptrs_src, void ** ptrs_dst,
-        int ne12, int ne13,
-        int ne23,
-        int nb02, int nb03,
-        int nb12, int nb13,
-        int nb2, int nb3,
-        int r2, int r3,
-        ggml_type src0_type, half * src0_as_f16, int64_t src0_ne,
-        const half * src1_f16, half * dst_f16,
-        const int32_t * ids, const int id,
-        Srcs... src0s) {
-
-    int i = ids[id];
-
-    half * src0_f16;
-    const void * srcs_ar[] = { (const half *) src0s... };
-    if (src0_type == GGML_TYPE_F16) {
-        src0_f16 = (half *) srcs_ar[i];
-    } else {
-        src0_f16 = src0_as_f16;
-        if (threadIdx.x == 0 && threadIdx.y == 0) {
-            const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type);
-            to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget);
-        }
-    }
-
-    int i13 = blockIdx.x * blockDim.x + threadIdx.x;
-    int i12 = blockIdx.y * blockDim.y + threadIdx.y;
-
-    if (i13 >= ne13 || i12 >= ne12) {
-        return;
-    }
-
-    int i03 = i13 / r3;
-    int i02 = i12 / r2;
-
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)  dst_f16 + i12* nb2/2 + i13* nb3/2;
-}
-
-static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
-    const struct ggml_tensor * ids = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * src00 = dst->src[2];
-
-    const int id = dst->op_params[0];
-
-    GGML_ASSERT(!ggml_is_transposed(src00));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-
-    GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
-    const int64_t ne01 = src00->ne[1];
-    const int64_t ne02 = src00->ne[2];
-    const int64_t ne03 = src00->ne[3];
-
-    //const int64_t nb01 = src00->nb[1];
-    const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02);
-    const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03);
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-
-    //const int64_t nb11 = src1->nb[1];
-    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
-    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
-
-    const int64_t ne1 = ggml_nelements(src1);
-    const int64_t ne  = ggml_nelements(dst);
-
-    ggml_cuda_set_device(g_main_device);
-    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
-
-    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
-
-    //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
-    //void * src0_ddq = src0_extra->data_device[g_main_device];
-    //half * src0_as_f16 = (half *) src0_ddq;
-
-    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
-    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
-
-    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
-    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
-
-    // convert src1 to fp16
-    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
-    GGML_ASSERT(to_fp16_cuda != nullptr);
-
-    size_t src1_as = 0;
-    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
-    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
-
-    size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
-
-    GGML_ASSERT(ne12 % ne02 == 0);
-    GGML_ASSERT(ne13 % ne03 == 0);
-
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
-
-    const half alpha_f16 = 1.0f;
-    const half beta_f16  = 0.0f;
-
-    // use cublasGemmBatchedEx
-    const int ne23 = ne12*ne13;
-
-    const void ** ptrs_src = nullptr;
-          void ** ptrs_dst = nullptr;
-
-    size_t ptrs_src_s = 0;
-    size_t ptrs_dst_s = 0;
-
-    ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
-    ptrs_dst = (      void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
-
-    int64_t src0_ne = ggml_nelements(src00);
-    half * src0_as_f16 = nullptr;
-    size_t src0_as = 0;
-    if (src00->type != GGML_TYPE_F16) {
-        src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as);
-    }
-
-    static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6");
-    dim3 block_dims(ne13, ne12);
-    k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>(
-            ptrs_src, ptrs_dst,
-            ne12, ne13,
-            ne23,
-            ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half),
-            nb12, nb13,
-            dst->nb[2], dst->nb[3],
-            r2, r3,
-            src00->type, src0_as_f16, src0_ne,
-            src1_as_f16, dst_f16,
-            (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id,
-            dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr,
-            dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr
-    );
-    CUDA_CHECK(cudaGetLastError());
-
-    CUBLAS_CHECK(
-    cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
-            ne01, ne11, ne10,
-            &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00,
-                        (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10,
-            &beta_f16,  (       void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
-            ne23,
-            CUBLAS_COMPUTE_16F,
-            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
-    if (src0_as != 0) {
-        ggml_cuda_pool_free(src0_as_f16, src0_as);
-    }
-    if (ptrs_src_s != 0) {
-        ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
-    }
-    if (ptrs_dst_s != 0) {
-        ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
-    }
-
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
-    ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
-}
-#endif
-
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-#if 0
-    ggml_cuda_mul_mat_id_cublas(dst);
-    // TODO: mmq/mmv support
-#endif
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * ids  = dst->src[2];
+
+    GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");

     cudaStream_t stream = ctx.stream();

     const size_t nb11 = src1->nb[1];
     const size_t nb1  =  dst->nb[1];

-    const struct ggml_tensor * ids = src0;
     const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = ((int32_t *) dst->op_params)[1];
+    const int32_t n_as = src0->ne[2];

     std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
     CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
     CUDA_CHECK(cudaStreamSynchronize(stream));

+    ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row = *dst;

+    char * src0_original = (char *) src0->data;
     char * src1_original = (char *) src1->data;
     char * dst_original  = (char *)  dst->data;

+    src0_row.ne[2] = 1;
+    src0_row.ne[3] = 1;
+    src0_row.nb[3] = src0->nb[2];
+
     if (src1->ne[1] == 1) {
         for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
             const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);

             GGML_ASSERT(row_id >= 0 && row_id < n_as);

-            const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
+            src0_row.data = src0_original + row_id*src0->nb[2];
             src1_row.data = src1_original + i01*src1->nb[1];
             dst_row.data  =  dst_original + i01*dst->nb[1];

-            ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
+            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
         }
     } else {
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@@ -2192,8 +2012,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         dst_row.data = dst_contiguous.get();

         for (int32_t row_id = 0; row_id < n_as; ++row_id) {
-            const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
             int64_t num_src1_rows = 0;
             for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
                 const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
@@ -2213,6 +2031,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
                 continue;
             }

+            src0_row.data = src0_original + row_id*src0->nb[2];
+
             src1_row.ne[1] = num_src1_rows;
             dst_row.ne[1] = num_src1_rows;

@@ -2224,7 +2044,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;

-            ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
+            ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);

             num_src1_rows = 0;
             for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
@@ -2389,7 +2209,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
     cudaError_t err = cudaGetLastError();
     if (err != cudaSuccess) {
         fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
-        GGML_ASSERT(false);
+        CUDA_CHECK(err);
     }

     return true;
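Note on the batched path above: when src1 has more than one row, the rows routed to the same expert are gathered into a contiguous buffer and a single matrix multiplication is issued per expert, each using the same `row_id*src0->nb[2]` slice of the merged tensor. A minimal host-side sketch of that grouping step, illustrative only (the actual gather is performed on the GPU into `src1_contiguous`):

```cpp
#include <cstdint>
#include <vector>

// Group token indices by the expert they are routed to, mirroring what the
// multi-row branch of ggml_cuda_mul_mat_id does before calling one matmul
// per expert. Names here are illustrative, not ggml API.
static std::vector<std::vector<int64_t>> group_rows_by_expert(
        const std::vector<int32_t> & ids, int32_t n_as) {
    std::vector<std::vector<int64_t>> rows_per_expert(n_as);
    for (int64_t i01 = 0; i01 < (int64_t) ids.size(); ++i01) {
        const int32_t row_id = ids[i01];
        if (row_id >= 0 && row_id < n_as) {
            rows_per_expert[row_id].push_back(i01); // this src1 row uses expert row_id
        }
    }
    return rows_per_expert;
}
```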

ggml-cuda/argsort.cu

Lines changed: 39 additions & 13 deletions
@@ -8,51 +8,77 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
 }

 template<ggml_sort_order order>
-static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
+static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad) {
     // bitonic sort
     int col = threadIdx.x;
     int row = blockIdx.y;

-    if (col >= ncols) return;
+    if (col >= ncols_pad) {
+        return;
+    }

     const float * x_row = x + row * ncols;
-    int * dst_row = dst + row * ncols;
+    extern __shared__ int dst_row[];

     // initialize indices
-    if (col < ncols) {
-        dst_row[col] = col;
-    }
+    dst_row[col] = col;
+
     __syncthreads();

-    for (int k = 2; k <= ncols; k *= 2) {
+    for (int k = 2; k <= ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
                         ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                     }
                 } else {
-                    if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
                         ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                     }
                 }
             }
             __syncthreads();
         }
     }
+
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
+}
+
+static int next_power_of_2(int x) {
+    int n = 1;
+    while (n < x) {
+        n *= 2;
+    }
+    return n;
 }

 static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
     // bitonic sort requires ncols to be power of 2
-    GGML_ASSERT((ncols & (ncols - 1)) == 0);
+    const int ncols_pad = next_power_of_2(ncols);

-    const dim3 block_dims(ncols, 1, 1);
+    const dim3 block_dims(ncols_pad, 1, 1);
     const dim3 block_nums(1, nrows, 1);
+    const size_t shared_mem = ncols_pad * sizeof(int);
+
+    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);
+
     if (order == GGML_SORT_ORDER_ASC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else if (order == GGML_SORT_ORDER_DESC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
+        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
     } else {
         GGML_ASSERT(false);
     }
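The padded bitonic argsort above relies on padding indices (>= ncols) behaving like +inf: they always compare as larger, so the final ascending merge pushes them past the real entries and the copy-out simply drops them. A CPU sketch of the same scheme for ascending order, illustrative only and not part of ggml:

```cpp
#include <cstdio>
#include <utility>
#include <vector>

// Pad to the next power of two, as the launcher above does.
static int next_power_of_2_ref(int x) {
    int n = 1;
    while (n < x) n *= 2;
    return n;
}

// Sequential emulation of the padded bitonic argsort network (ascending).
// Within each (k, j) stage the compared pairs are disjoint, so processing
// them one after another gives the same result as the parallel kernel.
static std::vector<int> argsort_asc_padded(const std::vector<float> & x) {
    const int ncols     = (int) x.size();
    const int ncols_pad = next_power_of_2_ref(ncols);

    std::vector<int> idx(ncols_pad);
    for (int i = 0; i < ncols_pad; ++i) idx[i] = i;

    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            for (int col = 0; col < ncols_pad; ++col) {
                const int ixj = col ^ j;
                if (ixj <= col) continue;
                const bool asc_block = (col & k) == 0;
                // padding indices compare as +inf
                const bool a_pad = idx[col] >= ncols;
                const bool b_pad = idx[ixj] >= ncols;
                const bool a_gt_b = a_pad || (!b_pad && x[idx[col]] > x[idx[ixj]]);
                const bool b_gt_a = b_pad || (!a_pad && x[idx[ixj]] > x[idx[col]]);
                if ((asc_block && a_gt_b) || (!asc_block && b_gt_a)) {
                    std::swap(idx[col], idx[ixj]);
                }
            }
        }
    }
    idx.resize(ncols); // drop the padding, as the kernel's copy-out does
    return idx;
}

int main() {
    const std::vector<float> x = { 0.3f, -1.0f, 2.5f, 0.0f, 1.5f }; // ncols = 5, not a power of 2
    for (int i : argsort_asc_padded(x)) printf("%d ", i);           // prints: 1 3 0 4 2
    printf("\n");
    return 0;
}
```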
