ggml_top_k cuda error on tensors greater than 1024 #1367

@Codes4Fun

Description

Here is a test case:

#include <cassert>
#include <cstdio>
#include <vector>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cuda.h"

#define TOP_K 8000
void top_k_cuda_test() {
    ggml_backend_t backend = NULL;

    //backend = ggml_backend_cpu_init();
    backend = ggml_backend_cuda_init(0);
    assert( backend );

    // create context
    auto ctx = ggml_init({
        /*.mem_size   =*/ ggml_tensor_overhead()*GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    });

    auto probs = ggml_arange( ctx, 0, TOP_K, 1 );
    auto result = ggml_top_k( ctx, probs, 512 );

    auto gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, result);

    auto buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
    ggml_backend_graph_compute(backend, gf);

    std::vector<int> out_data(ggml_nelements(result));
    ggml_backend_tensor_get(result, out_data.data(), 0, ggml_nbytes(result));

    printf("out_data %d %d %d %d ...\n", out_data[0], out_data[1], out_data[2], out_data[3] );

    printf("press enter to continue");
    getchar();

    // release backend memory and free backend
    ggml_backend_buffer_free(buffer);
    ggml_free(ctx);
    ggml_backend_free(backend);
}

It seems that CUDA doesn't support more than 1024 threads per block, and the argsort kernel launches one thread per padded column, so my test needs 8192 threads (8000 columns rounded up to the next power of 2).
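
For reference, the limit is easy to reproduce outside of ggml. Below is a minimal standalone CUDA sketch (the noop_kernel name is just for illustration, not part of ggml) showing that a launch with 8192 threads per block fails while 1024 succeeds:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void noop_kernel() {}

int main() {
    // 8000 columns are padded to the next power of 2, i.e. 8192 threads in one block
    noop_kernel<<<1, 8192>>>();
    printf("8192 threads/block: %s\n", cudaGetErrorString(cudaGetLastError()));
    // expected: "invalid configuration argument" on devices with a 1024-thread block limit

    noop_kernel<<<1, 1024>>>();
    printf("1024 threads/block: %s\n", cudaGetErrorString(cudaGetLastError()));
    // expected: "no error"
    return 0;
}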

I created and tested the following workaround in argsort.cu, which works in the few test cases I have:

template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, int nchunk) {
    // bitonic sort
    // each thread handles nchunk consecutive columns so that ncols_pad can exceed
    // the 1024-threads-per-block limit
    int col_start = threadIdx.x * nchunk;
    int col_end = col_start + nchunk;
    if (col_end > ncols_pad) col_end = ncols_pad;
    int row = blockIdx.y;

    if (col_start >= ncols_pad) {
        return;
    }

    extern __shared__ int dst_row[];

    // initialize indices
    for (int col = col_start; col < col_end; col++) {
        dst_row[col] = col;
    }

    __syncthreads();

    const float * x_row = x + row * ncols;

    for (int k = 2; k <= ncols_pad; k *= 2) {
        for (int j = k / 2; j > 0; j /= 2) {
            for (int col = col_start; col < col_end; col++) {
                int ixj = col ^ j;
                if (ixj > col) {
                    if ((col & k) == 0) {
                        if (dst_row[col] >= ncols ||
                            (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
                                x_row[dst_row[col]] > x_row[dst_row[ixj]] :
                                x_row[dst_row[col]] < x_row[dst_row[ixj]]))
                        ) {
                            ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                        }
                    } else {
                        if (dst_row[ixj] >= ncols ||
                            (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
                                x_row[dst_row[col]] < x_row[dst_row[ixj]] :
                                x_row[dst_row[col]] > x_row[dst_row[ixj]]))
                        ) {
                            ggml_cuda_swap(dst_row[col], dst_row[ixj]);
                        }
                    }
                }
            }
            __syncthreads();
        }
    }

    // copy the result to dst without the padding
    for (int col = col_start; col < col_end; col++) {
        if (col < ncols) {
            dst[row * ncols + col] = dst_row[col];
        }
    }
}

static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
    // bitonic sort requires ncols to be power of 2
    const int ncols_pad = next_power_of_2(ncols);
    int nchunks, dim0;
    if (ncols_pad > 1024) {
        // more columns than the 1024-thread block limit: give each thread a chunk of columns
        dim0 = 1024;
        nchunks = ncols_pad / 1024;
    } else {
        dim0 = ncols_pad;
        nchunks = 1;
    }

    const dim3 block_dims(dim0, 1, 1);
    const dim3 block_nums(1, nrows, 1);
    const size_t shared_mem = ncols_pad * sizeof(int);

    // FIXME: this limit could be raised by ~2-4x on Ampere or newer
    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

    if (order == GGML_SORT_ORDER_ASC) {
        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, nchunks);
    } else if (order == GGML_SORT_ORDER_DESC) {
        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad, nchunks);
    } else {
        GGML_ABORT("fatal error");
    }
}

I am not familiar enough with CUDA to say whether there is a better way to do this, but the workaround simply assigns more than one column per thread when the number of columns exceeds the 1024-thread-per-block limit, and I believe it ensures the threads don't stomp on each other's changes.
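
To illustrate why the chunking is safe, here is a small host-side sanity check (the check_chunking helper is mine, not part of the patch) showing that the thread-to-column mapping used above partitions the padded columns into disjoint, exhaustive chunks:

#include <cassert>

// assumes ncols_pad is a power of 2, as guaranteed by next_power_of_2()
static void check_chunking(int ncols_pad) {
    const int dim0   = ncols_pad > 1024 ? 1024 : ncols_pad;     // threads per block
    const int nchunk = ncols_pad > 1024 ? ncols_pad / 1024 : 1; // columns per thread
    int covered = 0;
    for (int tid = 0; tid < dim0; tid++) {
        int col_start = tid * nchunk;
        int col_end   = col_start + nchunk;
        if (col_end > ncols_pad) col_end = ncols_pad;
        covered += col_end - col_start;
    }
    assert(covered == ncols_pad); // every padded column is owned by exactly one thread
}

Because the chunks are disjoint and a pair (col, col ^ j) is only swapped by the thread whose loop variable is the lower index, no two threads write the same dst_row entry between __syncthreads() calls.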
