mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-09 15:29:43 +00:00
cuda : performance optimizations (#1530)
* xor hack * block y dim * loop unrolling * Fixed cmake LLAMA_CUDA_BY option * Removed hipblas compatibility code * Define GGML_CUDA_DMMV_BLOCK_Y if not defined * Fewer iters, more ops per iter * Renamed DMMV X/Y compilation options
This commit is contained in:
parent
ac7876ac20
commit
1fcdcc28b1
3 changed files with 110 additions and 64 deletions
|
@ -68,6 +68,8 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework"
|
|||
option(LLAMA_BLAS "llama: use BLAS" OFF)
|
||||
option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
|
||||
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
|
||||
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
|
||||
set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
|
||||
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
|
||||
|
||||
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
|
||||
|
@ -184,6 +186,8 @@ if (LLAMA_CUBLAS)
|
|||
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
|
||||
|
||||
add_compile_definitions(GGML_USE_CUBLAS)
|
||||
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
|
||||
add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
|
||||
|
||||
if (LLAMA_STATIC)
|
||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
||||
|
|
12
Makefile
12
Makefile
|
@ -133,9 +133,19 @@ ifdef LLAMA_CUBLAS
|
|||
OBJS += ggml-cuda.o
|
||||
NVCC = nvcc
|
||||
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
|
||||
ifdef LLAMA_CUDA_DMMV_X
|
||||
NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
|
||||
else
|
||||
NVCCFLAGS += -DGGML_CUDA_DMMV_X=32
|
||||
endif # LLAMA_CUDA_DMMV_X
|
||||
ifdef LLAMA_CUDA_DMMV_Y
|
||||
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y)
|
||||
else
|
||||
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
|
||||
endif # LLAMA_CUDA_DMMV_Y
|
||||
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
|
||||
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
|
||||
endif
|
||||
endif # LLAMA_CUBLAS
|
||||
ifdef LLAMA_CLBLAST
|
||||
CFLAGS += -DGGML_USE_CLBLAST
|
||||
CXXFLAGS += -DGGML_USE_CLBLAST
|
||||
|
|
104
ggml-cuda.cu
104
ggml-cuda.cu
|
@ -83,9 +83,19 @@ typedef struct {
|
|||
} block_q8_0;
|
||||
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
|
||||
|
||||
#define WARP_SIZE 32
|
||||
|
||||
#define CUDA_MUL_BLOCK_SIZE 256
|
||||
|
||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||
#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
|
||||
|
||||
// dmmv = dequantize_mul_mat_vec
|
||||
#ifndef GGML_CUDA_DMMV_X
|
||||
#define GGML_CUDA_DMMV_X 32
|
||||
#endif
|
||||
#ifndef GGML_CUDA_DMMV_Y
|
||||
#define GGML_CUDA_DMMV_Y 1
|
||||
#endif
|
||||
|
||||
static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
@ -200,41 +210,51 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
|
|||
dequantize_kernel(vx, ib, iqs, v0, v1);
|
||||
}
|
||||
|
||||
template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
|
||||
const int row = blockIdx.x;
|
||||
// qk = quantized weights per x block
|
||||
// qr = number of quantized weights per data value in x block
|
||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const int iter_stride = 2*GGML_CUDA_DMMV_X;
|
||||
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
|
||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||
|
||||
__shared__ float tmp[block_size]; // separate sum for each thread
|
||||
tmp[tid] = 0;
|
||||
float tmp = 0; // partial sum for thread in warp
|
||||
|
||||
for (int i = 0; i < ncols/block_size; i += 2) {
|
||||
const int col = i*block_size + 2*tid;
|
||||
const int ib = (row*ncols + col)/qk; // block index
|
||||
const int iqs = (col%qk)/qr; // quant index
|
||||
for (int i = 0; i < ncols; i += iter_stride) {
|
||||
const int col = i + vals_per_iter*tid;
|
||||
const int ib = (row*ncols + col)/qk; // x block index
|
||||
const int iqs = (col%qk)/qr; // x quant index
|
||||
const int iybs = col - col%qk; // y block start index
|
||||
|
||||
// processing >2 values per i iter is faster for fast GPUs
|
||||
#pragma unroll
|
||||
for (int j = 0; j < vals_per_iter; j += 2) {
|
||||
// process 2 vals per j iter
|
||||
|
||||
// dequantize
|
||||
float v0, v1;
|
||||
dequantize_kernel(vx, ib, iqs, v0, v1);
|
||||
dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
|
||||
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
|
||||
|
||||
// matrix multiplication
|
||||
tmp[tid] += v0 * y[iybs + iqs + 0];
|
||||
tmp[tid] += v1 * y[iybs + iqs + y_offset];
|
||||
tmp += v0 * y[iybs + iqs + j/qr + 0];
|
||||
tmp += v1 * y[iybs + iqs + j/qr + y_offset];
|
||||
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
|
||||
}
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
__syncthreads();
|
||||
for (int s=block_size/2; s>0; s>>=1) {
|
||||
if (tid < s) {
|
||||
tmp[tid] += tmp[tid + s];
|
||||
}
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
dst[row] = tmp[0];
|
||||
dst[row] = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -269,33 +289,43 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
|
|||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
|
||||
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
|
||||
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
|
||||
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
|
||||
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
|
||||
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
|
||||
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
|
||||
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
|
||||
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
|
||||
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
||||
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
|
||||
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
|
@ -304,9 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
|
|||
}
|
||||
|
||||
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
|
||||
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||
GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
|
||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
|
||||
dequantize_mul_mat_vec<1, 1, convert_f16>
|
||||
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||
|
|
Loading…
Reference in a new issue