mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-09 15:29:43 +00:00
ggml : GPU-accelerated token generation (#1412)
* CUDA kernel for q4_0 dequant. + mat. vec. mult. * Added q4_1 via template * Added missing __syncthreads(); * --gpu_layers -> --gpu-layers * Shorter dequantize_mul_mat_vec line * q5_0 dequantize_mul_mat kernel * More readable dequantize_mul_mat_vec logic * dequantize_mul_mat_vec kernels for q5_1, q8_0, f16 * llama : offload "output" tensor to GPU too + coding style fixes --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
f954edda93
commit
905d87b70a
8 changed files with 336 additions and 42 deletions
|
@ -277,6 +277,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||||
params.use_color = true;
|
params.use_color = true;
|
||||||
} else if (arg == "--mlock") {
|
} else if (arg == "--mlock") {
|
||||||
params.use_mlock = true;
|
params.use_mlock = true;
|
||||||
|
} else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.n_gpu_layers = std::stoi(argv[i]);
|
||||||
} else if (arg == "--no-mmap") {
|
} else if (arg == "--no-mmap") {
|
||||||
params.use_mmap = false;
|
params.use_mmap = false;
|
||||||
} else if (arg == "--mtest") {
|
} else if (arg == "--mtest") {
|
||||||
|
@ -421,6 +427,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||||
if (llama_mmap_supported()) {
|
if (llama_mmap_supported()) {
|
||||||
fprintf(stderr, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
|
fprintf(stderr, " --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
|
||||||
}
|
}
|
||||||
|
fprintf(stderr, " -ngl N, --n-gpu-layers N\n");
|
||||||
|
fprintf(stderr, " number of layers to store in VRAM\n");
|
||||||
fprintf(stderr, " --mtest compute maximum memory usage\n");
|
fprintf(stderr, " --mtest compute maximum memory usage\n");
|
||||||
fprintf(stderr, " --verbose-prompt print prompt before generation\n");
|
fprintf(stderr, " --verbose-prompt print prompt before generation\n");
|
||||||
fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
|
fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
|
||||||
|
@ -463,14 +471,15 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
|
||||||
struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
|
struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
|
||||||
auto lparams = llama_context_default_params();
|
auto lparams = llama_context_default_params();
|
||||||
|
|
||||||
lparams.n_ctx = params.n_ctx;
|
lparams.n_ctx = params.n_ctx;
|
||||||
lparams.n_parts = params.n_parts;
|
lparams.n_parts = params.n_parts;
|
||||||
lparams.seed = params.seed;
|
lparams.n_gpu_layers = params.n_gpu_layers;
|
||||||
lparams.f16_kv = params.memory_f16;
|
lparams.seed = params.seed;
|
||||||
lparams.use_mmap = params.use_mmap;
|
lparams.f16_kv = params.memory_f16;
|
||||||
lparams.use_mlock = params.use_mlock;
|
lparams.use_mmap = params.use_mmap;
|
||||||
lparams.logits_all = params.perplexity;
|
lparams.use_mlock = params.use_mlock;
|
||||||
lparams.embedding = params.embedding;
|
lparams.logits_all = params.perplexity;
|
||||||
|
lparams.embedding = params.embedding;
|
||||||
|
|
||||||
llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
|
llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
|
||||||
|
|
||||||
|
|
|
@ -21,13 +21,14 @@
|
||||||
int32_t get_num_physical_cores();
|
int32_t get_num_physical_cores();
|
||||||
|
|
||||||
struct gpt_params {
|
struct gpt_params {
|
||||||
int32_t seed = -1; // RNG seed
|
int32_t seed = -1; // RNG seed
|
||||||
int32_t n_threads = get_num_physical_cores();
|
int32_t n_threads = get_num_physical_cores();
|
||||||
int32_t n_predict = -1; // new tokens to predict
|
int32_t n_predict = -1; // new tokens to predict
|
||||||
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
|
int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions)
|
||||||
int32_t n_ctx = 512; // context size
|
int32_t n_ctx = 512; // context size
|
||||||
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
||||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||||
|
int32_t n_gpu_layers = 0; // number of layers to store in VRAM
|
||||||
|
|
||||||
// sampling parameters
|
// sampling parameters
|
||||||
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
||||||
|
|
287
ggml-cuda.cu
287
ggml-cuda.cu
|
@ -32,9 +32,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
|
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
|
||||||
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
|
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
|
||||||
|
typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
|
||||||
|
|
||||||
|
// QK = number of values after dequantization
|
||||||
|
// QR = QK / number of values before dequantization
|
||||||
|
|
||||||
#define QK4_0 32
|
#define QK4_0 32
|
||||||
|
#define QR4_0 2
|
||||||
typedef struct {
|
typedef struct {
|
||||||
float d; // delta
|
float d; // delta
|
||||||
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
||||||
|
@ -42,6 +48,7 @@ typedef struct {
|
||||||
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
|
static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding");
|
||||||
|
|
||||||
#define QK4_1 32
|
#define QK4_1 32
|
||||||
|
#define QR4_1 2
|
||||||
typedef struct {
|
typedef struct {
|
||||||
float d; // delta
|
float d; // delta
|
||||||
float m; // min
|
float m; // min
|
||||||
|
@ -50,6 +57,7 @@ typedef struct {
|
||||||
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
|
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
|
||||||
|
|
||||||
#define QK5_0 32
|
#define QK5_0 32
|
||||||
|
#define QR5_0 2
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // delta
|
half d; // delta
|
||||||
uint8_t qh[4]; // 5-th bit of quants
|
uint8_t qh[4]; // 5-th bit of quants
|
||||||
|
@ -58,6 +66,7 @@ typedef struct {
|
||||||
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
|
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
|
||||||
|
|
||||||
#define QK5_1 32
|
#define QK5_1 32
|
||||||
|
#define QR5_1 2
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // delta
|
half d; // delta
|
||||||
half m; // min
|
half m; // min
|
||||||
|
@ -67,12 +76,100 @@ typedef struct {
|
||||||
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
|
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
|
||||||
|
|
||||||
#define QK8_0 32
|
#define QK8_0 32
|
||||||
|
#define QR8_0 1
|
||||||
typedef struct {
|
typedef struct {
|
||||||
float d; // delta
|
float d; // delta
|
||||||
int8_t qs[QK8_0]; // quants
|
int8_t qs[QK8_0]; // quants
|
||||||
} block_q8_0;
|
} block_q8_0;
|
||||||
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
|
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
|
||||||
|
|
||||||
|
#define CUDA_DMMV_BLOCK_SIZE 32
|
||||||
|
|
||||||
|
static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
||||||
|
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||||
|
|
||||||
|
const float d = x[ib].d;
|
||||||
|
|
||||||
|
const uint8_t vui = x[ib].qs[iqs];
|
||||||
|
|
||||||
|
const int8_t vi0 = vui & 0xF;
|
||||||
|
const int8_t vi1 = vui >> 4;
|
||||||
|
|
||||||
|
v0 = (vi0 - 8)*d;
|
||||||
|
v1 = (vi1 - 8)*d;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
||||||
|
const block_q4_1 * x = (const block_q4_1 *) vx;
|
||||||
|
|
||||||
|
const float d = x[ib].d;
|
||||||
|
const float m = x[ib].m;
|
||||||
|
|
||||||
|
const uint8_t vui = x[ib].qs[iqs];
|
||||||
|
|
||||||
|
const int8_t vi0 = vui & 0xF;
|
||||||
|
const int8_t vi1 = vui >> 4;
|
||||||
|
|
||||||
|
v0 = vi0*d + m;
|
||||||
|
v1 = vi1*d + m;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
||||||
|
const block_q5_0 * x = (const block_q5_0 *) vx;
|
||||||
|
|
||||||
|
const float d = x[ib].d;
|
||||||
|
|
||||||
|
uint32_t qh;
|
||||||
|
memcpy(&qh, x[ib].qh, sizeof(qh));
|
||||||
|
|
||||||
|
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
||||||
|
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
||||||
|
|
||||||
|
const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
|
||||||
|
const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;
|
||||||
|
|
||||||
|
v0 = x0*d;
|
||||||
|
v1 = x1*d;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
||||||
|
const block_q5_1 * x = (const block_q5_1 *) vx;
|
||||||
|
|
||||||
|
const float d = x[ib].d;
|
||||||
|
const float m = x[ib].m;
|
||||||
|
|
||||||
|
uint32_t qh;
|
||||||
|
memcpy(&qh, x[ib].qh, sizeof(qh));
|
||||||
|
|
||||||
|
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
||||||
|
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
||||||
|
|
||||||
|
const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
|
||||||
|
const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
|
||||||
|
|
||||||
|
v0 = x0*d + m;
|
||||||
|
v1 = x1*d + m;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
||||||
|
const block_q8_0 * x = (const block_q8_0 *) vx;
|
||||||
|
|
||||||
|
const float d = x[ib].d;
|
||||||
|
|
||||||
|
const int8_t vi0 = x[ib].qs[iqs + 0];
|
||||||
|
const int8_t vi1 = x[ib].qs[iqs + 1];
|
||||||
|
|
||||||
|
v0 = vi0*d;
|
||||||
|
v1 = vi1*d;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
||||||
|
const half * x = (const half *) vx;
|
||||||
|
|
||||||
|
v0 = __half2float(x[ib + 0]);
|
||||||
|
v1 = __half2float(x[ib + 1]);
|
||||||
|
}
|
||||||
|
|
||||||
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
|
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
|
||||||
static const int qk = QK4_0;
|
static const int qk = QK4_0;
|
||||||
|
|
||||||
|
@ -173,6 +270,44 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||||
|
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
|
||||||
|
const int row = blockIdx.x;
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
|
||||||
|
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||||
|
|
||||||
|
__shared__ float tmp[block_size]; // separate sum for each thread
|
||||||
|
tmp[tid] = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < ncols/block_size; i += 2) {
|
||||||
|
const int col = i*block_size + 2*tid;
|
||||||
|
const int ib = (row*ncols + col)/qk; // block index
|
||||||
|
const int iqs = (col%qk)/qr; // quant index
|
||||||
|
const int iybs = col - col%qk; // y block start index
|
||||||
|
|
||||||
|
// dequantize
|
||||||
|
float v0, v1;
|
||||||
|
dequantize_kernel(vx, ib, iqs, v0, v1);
|
||||||
|
|
||||||
|
// matrix multiplication
|
||||||
|
tmp[tid] += v0 * y[iybs + iqs + 0];
|
||||||
|
tmp[tid] += v1 * y[iybs + iqs + y_offset];
|
||||||
|
}
|
||||||
|
|
||||||
|
// sum up partial sums and write back result
|
||||||
|
__syncthreads();
|
||||||
|
for (int s=block_size/2; s>0; s>>=1) {
|
||||||
|
if (tid < s) {
|
||||||
|
tmp[tid] += tmp[tid + s];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
if (tid == 0) {
|
||||||
|
dst[row] = tmp[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||||
const int nb = k / QK4_0;
|
const int nb = k / QK4_0;
|
||||||
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
|
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||||
|
@ -198,6 +333,36 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
|
||||||
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
|
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||||
|
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0>
|
||||||
|
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||||
|
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1>
|
||||||
|
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||||
|
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0>
|
||||||
|
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||||
|
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1>
|
||||||
|
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||||
|
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0>
|
||||||
|
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: optimize
|
// TODO: optimize
|
||||||
static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
|
static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
|
||||||
const half * x = (const half *) vx;
|
const half * x = (const half *) vx;
|
||||||
|
@ -211,6 +376,12 @@ static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStre
|
||||||
convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
|
convert_fp16_to_fp32<<<k, 1, 0, stream>>>(x, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
|
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0);
|
||||||
|
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16>
|
||||||
|
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols);
|
||||||
|
}
|
||||||
|
|
||||||
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
|
@ -230,8 +401,27 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
|
||||||
|
switch (type) {
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
return dequantize_mul_mat_vec_q4_0_cuda;
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
return dequantize_mul_mat_vec_q4_1_cuda;
|
||||||
|
case GGML_TYPE_Q5_0:
|
||||||
|
return dequantize_mul_mat_vec_q5_0_cuda;
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
|
return dequantize_mul_mat_vec_q5_1_cuda;
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
return dequantize_mul_mat_vec_q8_0_cuda;
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
return dequantize_mul_mat_vec_q8_0_cuda;
|
||||||
|
default:
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// buffer pool for cuda
|
// buffer pool for cuda
|
||||||
#define MAX_CUDA_BUFFERS 16
|
#define MAX_CUDA_BUFFERS 256
|
||||||
|
|
||||||
struct scoped_spin_lock {
|
struct scoped_spin_lock {
|
||||||
std::atomic_flag& lock;
|
std::atomic_flag& lock;
|
||||||
|
@ -528,6 +718,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
|
||||||
const int nb2 = dst->nb[2];
|
const int nb2 = dst->nb[2];
|
||||||
const int nb3 = dst->nb[3];
|
const int nb3 = dst->nb[3];
|
||||||
const ggml_type type = src0->type;
|
const ggml_type type = src0->type;
|
||||||
|
const bool mul_mat_vec = ne11 == 1;
|
||||||
|
|
||||||
const float alpha = 1.0f;
|
const float alpha = 1.0f;
|
||||||
const float beta = 0.0f;
|
const float beta = 0.0f;
|
||||||
|
@ -538,12 +729,16 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
|
||||||
const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
|
const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
|
||||||
|
|
||||||
size_t x_size, y_size, d_size, q_size;
|
size_t x_size, y_size, d_size, q_size;
|
||||||
float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
|
float * d_X = nullptr;
|
||||||
|
if (!mul_mat_vec) {
|
||||||
|
d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
|
||||||
|
}
|
||||||
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
|
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
|
||||||
float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
|
float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
|
||||||
char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
|
char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
|
||||||
|
|
||||||
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
|
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
|
||||||
|
dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
|
||||||
GGML_ASSERT(to_fp32_cuda != nullptr);
|
GGML_ASSERT(to_fp32_cuda != nullptr);
|
||||||
|
|
||||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||||
|
@ -553,31 +748,54 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
|
||||||
cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
|
cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
|
||||||
cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
|
cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
|
||||||
|
|
||||||
float * c_X = d_X + i * x_ne;
|
|
||||||
float * c_Y = d_Y + i * y_ne;
|
float * c_Y = d_Y + i * y_ne;
|
||||||
float * c_D = d_D + i * d_ne;
|
float * c_D = d_D + i * d_ne;
|
||||||
char * c_Q = d_Q + i * q_sz;
|
char * c_Q = d_Q + i * q_sz;
|
||||||
|
|
||||||
// copy src0 and convert to fp32 on device
|
// copy src0 to device if necessary
|
||||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
|
if (src0->backend == GGML_BACKEND_CPU) {
|
||||||
to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
|
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
|
||||||
CUDA_CHECK(cudaGetLastError());
|
} else if (src0->backend == GGML_BACKEND_CUDA) {
|
||||||
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
|
c_Q = ((char *) src0->data) + i * q_sz;
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
}
|
||||||
|
if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
|
||||||
|
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
|
||||||
|
|
||||||
// copy src1 to device
|
// copy src1 to device
|
||||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
|
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
|
||||||
|
|
||||||
// wait for conversion
|
// wait for data
|
||||||
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
|
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
|
||||||
|
|
||||||
// compute
|
// compute
|
||||||
CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
|
dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
|
||||||
CUBLAS_CHECK(
|
CUDA_CHECK(cudaGetLastError());
|
||||||
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
|
||||||
ne01, ne11, ne10,
|
} else { // general dequantization kernel + cuBLAS matrix matrix multiplication
|
||||||
&alpha, c_X, ne00,
|
float * c_X = d_X + i * x_ne;
|
||||||
c_Y, ne10,
|
|
||||||
&beta, c_D, ne01));
|
// convert src0 to fp32 on device
|
||||||
|
to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
|
||||||
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
|
||||||
|
|
||||||
|
// copy src1 to device
|
||||||
|
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
|
||||||
|
|
||||||
|
// wait for conversion
|
||||||
|
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
|
||||||
|
|
||||||
|
// compute
|
||||||
|
CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
|
||||||
|
CUBLAS_CHECK(
|
||||||
|
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||||
|
ne01, ne11, ne10,
|
||||||
|
&alpha, c_X, ne00,
|
||||||
|
c_Y, ne10,
|
||||||
|
&beta, c_D, ne01));
|
||||||
|
}
|
||||||
|
|
||||||
// copy dst to host
|
// copy dst to host
|
||||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||||
|
@ -586,7 +804,9 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
CUDA_CHECK(cudaDeviceSynchronize());
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
ggml_cuda_pool_free(d_X, x_size);
|
if (!mul_mat_vec) {
|
||||||
|
ggml_cuda_pool_free(d_X, x_size);
|
||||||
|
}
|
||||||
ggml_cuda_pool_free(d_Y, y_size);
|
ggml_cuda_pool_free(d_Y, y_size);
|
||||||
ggml_cuda_pool_free(d_D, d_size);
|
ggml_cuda_pool_free(d_D, d_size);
|
||||||
ggml_cuda_pool_free(d_Q, q_size);
|
ggml_cuda_pool_free(d_Q, q_size);
|
||||||
|
@ -602,8 +822,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
|
||||||
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
|
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
|
||||||
src1->type == GGML_TYPE_F32 &&
|
src1->type == GGML_TYPE_F32 &&
|
||||||
dst->type == GGML_TYPE_F32 &&
|
dst->type == GGML_TYPE_F32 &&
|
||||||
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
|
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -655,3 +874,25 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
|
||||||
|
const int64_t ne0 = tensor->ne[0];
|
||||||
|
const int64_t ne1 = tensor->ne[1];
|
||||||
|
const int64_t ne2 = tensor->ne[2];
|
||||||
|
const int64_t ne3 = tensor->ne[3];
|
||||||
|
|
||||||
|
const ggml_type type = tensor->type;
|
||||||
|
const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
|
||||||
|
|
||||||
|
size_t q_size;
|
||||||
|
char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
|
||||||
|
|
||||||
|
cudaStream_t cudaStream2 = g_cudaStreams2[0];
|
||||||
|
|
||||||
|
// copy tensor to device
|
||||||
|
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2));
|
||||||
|
CUDA_CHECK(cudaDeviceSynchronize());
|
||||||
|
|
||||||
|
tensor->data = d_Q;
|
||||||
|
tensor->backend = GGML_BACKEND_CUDA;
|
||||||
|
}
|
||||||
|
|
|
@ -14,6 +14,8 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
|
||||||
void * ggml_cuda_host_malloc(size_t size);
|
void * ggml_cuda_host_malloc(size_t size);
|
||||||
void ggml_cuda_host_free(void * ptr);
|
void ggml_cuda_host_free(void * ptr);
|
||||||
|
|
||||||
|
void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
1
ggml.c
1
ggml.c
|
@ -3882,6 +3882,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
|
||||||
|
|
||||||
*result = (struct ggml_tensor) {
|
*result = (struct ggml_tensor) {
|
||||||
/*.type =*/ type,
|
/*.type =*/ type,
|
||||||
|
/*.backend =*/ GGML_BACKEND_CPU,
|
||||||
/*.n_dims =*/ n_dims,
|
/*.n_dims =*/ n_dims,
|
||||||
/*.ne =*/ { 1, 1, 1, 1 },
|
/*.ne =*/ { 1, 1, 1, 1 },
|
||||||
/*.nb =*/ { 0, 0, 0, 0 },
|
/*.nb =*/ { 0, 0, 0, 0 },
|
||||||
|
|
8
ggml.h
8
ggml.h
|
@ -243,6 +243,11 @@ extern "C" {
|
||||||
GGML_TYPE_COUNT,
|
GGML_TYPE_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum ggml_backend {
|
||||||
|
GGML_BACKEND_CPU = 0,
|
||||||
|
GGML_BACKEND_CUDA = 1,
|
||||||
|
};
|
||||||
|
|
||||||
// model file types
|
// model file types
|
||||||
enum ggml_ftype {
|
enum ggml_ftype {
|
||||||
GGML_FTYPE_UNKNOWN = -1,
|
GGML_FTYPE_UNKNOWN = -1,
|
||||||
|
@ -333,6 +338,7 @@ extern "C" {
|
||||||
// n-dimensional tensor
|
// n-dimensional tensor
|
||||||
struct ggml_tensor {
|
struct ggml_tensor {
|
||||||
enum ggml_type type;
|
enum ggml_type type;
|
||||||
|
enum ggml_backend backend;
|
||||||
|
|
||||||
int n_dims;
|
int n_dims;
|
||||||
int64_t ne[GGML_MAX_DIMS]; // number of elements
|
int64_t ne[GGML_MAX_DIMS]; // number of elements
|
||||||
|
@ -363,7 +369,7 @@ extern "C" {
|
||||||
|
|
||||||
char name[32];
|
char name[32];
|
||||||
|
|
||||||
char padding[8]; // TODO: remove and add padding to name?
|
char padding[9]; // TODO: remove and add padding to name?
|
||||||
};
|
};
|
||||||
|
|
||||||
// computation graph
|
// computation graph
|
||||||
|
|
37
llama.cpp
37
llama.cpp
|
@ -9,6 +9,9 @@
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
|
#ifdef GGML_USE_CUBLAS
|
||||||
|
#include "ggml-cuda.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <ctime>
|
#include <ctime>
|
||||||
|
@ -810,6 +813,7 @@ struct llama_context_params llama_context_default_params() {
|
||||||
struct llama_context_params result = {
|
struct llama_context_params result = {
|
||||||
/*.n_ctx =*/ 512,
|
/*.n_ctx =*/ 512,
|
||||||
/*.n_parts =*/ -1,
|
/*.n_parts =*/ -1,
|
||||||
|
/*.gpu_layers =*/ 0,
|
||||||
/*.seed =*/ -1,
|
/*.seed =*/ -1,
|
||||||
/*.f16_kv =*/ false,
|
/*.f16_kv =*/ false,
|
||||||
/*.logits_all =*/ false,
|
/*.logits_all =*/ false,
|
||||||
|
@ -876,6 +880,7 @@ static void llama_model_load_internal(
|
||||||
const std::string & fname,
|
const std::string & fname,
|
||||||
llama_context & lctx,
|
llama_context & lctx,
|
||||||
int n_ctx,
|
int n_ctx,
|
||||||
|
int n_gpu_layers,
|
||||||
ggml_type memory_type,
|
ggml_type memory_type,
|
||||||
bool use_mmap,
|
bool use_mmap,
|
||||||
bool use_mlock,
|
bool use_mlock,
|
||||||
|
@ -1022,6 +1027,33 @@ static void llama_model_load_internal(
|
||||||
ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
|
ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
|
||||||
|
|
||||||
model.mapping = std::move(ml->mapping);
|
model.mapping = std::move(ml->mapping);
|
||||||
|
#ifdef GGML_USE_CUBLAS
|
||||||
|
{
|
||||||
|
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
|
||||||
|
|
||||||
|
fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu);
|
||||||
|
|
||||||
|
size_t vram_total = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < n_gpu; ++i) {
|
||||||
|
const auto & layer = model.layers[i];
|
||||||
|
|
||||||
|
ggml_cuda_transform_tensor(layer.wq); vram_total += ggml_nbytes(layer.wq);
|
||||||
|
ggml_cuda_transform_tensor(layer.wk); vram_total += ggml_nbytes(layer.wk);
|
||||||
|
ggml_cuda_transform_tensor(layer.wv); vram_total += ggml_nbytes(layer.wv);
|
||||||
|
ggml_cuda_transform_tensor(layer.wo); vram_total += ggml_nbytes(layer.wo);
|
||||||
|
ggml_cuda_transform_tensor(layer.w1); vram_total += ggml_nbytes(layer.w1);
|
||||||
|
ggml_cuda_transform_tensor(layer.w2); vram_total += ggml_nbytes(layer.w2);
|
||||||
|
ggml_cuda_transform_tensor(layer.w3); vram_total += ggml_nbytes(layer.w3);
|
||||||
|
}
|
||||||
|
if (n_gpu_layers > (int) hparams.n_layer) {
|
||||||
|
fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__);
|
||||||
|
ggml_cuda_transform_tensor(model.output); vram_total += ggml_nbytes(model.output);
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
// loading time will be recalculate after the first eval, so
|
// loading time will be recalculate after the first eval, so
|
||||||
// we take page faults deferred by mmap() into consideration
|
// we take page faults deferred by mmap() into consideration
|
||||||
|
@ -1032,6 +1064,7 @@ static bool llama_model_load(
|
||||||
const std::string & fname,
|
const std::string & fname,
|
||||||
llama_context & lctx,
|
llama_context & lctx,
|
||||||
int n_ctx,
|
int n_ctx,
|
||||||
|
int n_gpu_layers,
|
||||||
ggml_type memory_type,
|
ggml_type memory_type,
|
||||||
bool use_mmap,
|
bool use_mmap,
|
||||||
bool use_mlock,
|
bool use_mlock,
|
||||||
|
@ -1039,7 +1072,7 @@ static bool llama_model_load(
|
||||||
llama_progress_callback progress_callback,
|
llama_progress_callback progress_callback,
|
||||||
void *progress_callback_user_data) {
|
void *progress_callback_user_data) {
|
||||||
try {
|
try {
|
||||||
llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock,
|
llama_model_load_internal(fname, lctx, n_ctx, n_gpu_layers, memory_type, use_mmap, use_mlock,
|
||||||
vocab_only, progress_callback, progress_callback_user_data);
|
vocab_only, progress_callback, progress_callback_user_data);
|
||||||
return true;
|
return true;
|
||||||
} catch (const std::string & err) {
|
} catch (const std::string & err) {
|
||||||
|
@ -2111,7 +2144,7 @@ struct llama_context * llama_init_from_file(
|
||||||
|
|
||||||
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||||
|
|
||||||
if (!llama_model_load(path_model, *ctx, params.n_ctx, memory_type,
|
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type,
|
||||||
params.use_mmap, params.use_mlock, params.vocab_only,
|
params.use_mmap, params.use_mlock, params.vocab_only,
|
||||||
params.progress_callback, params.progress_callback_user_data)) {
|
params.progress_callback, params.progress_callback_user_data)) {
|
||||||
fprintf(stderr, "%s: failed to load model\n", __func__);
|
fprintf(stderr, "%s: failed to load model\n", __func__);
|
||||||
|
|
7
llama.h
7
llama.h
|
@ -54,9 +54,10 @@ extern "C" {
|
||||||
typedef void (*llama_progress_callback)(float progress, void *ctx);
|
typedef void (*llama_progress_callback)(float progress, void *ctx);
|
||||||
|
|
||||||
struct llama_context_params {
|
struct llama_context_params {
|
||||||
int n_ctx; // text context
|
int n_ctx; // text context
|
||||||
int n_parts; // -1 for default
|
int n_parts; // -1 for default
|
||||||
int seed; // RNG seed, -1 for random
|
int n_gpu_layers; // number of layers to store in VRAM
|
||||||
|
int seed; // RNG seed, -1 for random
|
||||||
|
|
||||||
bool f16_kv; // use fp16 for KV cache
|
bool f16_kv; // use fp16 for KV cache
|
||||||
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
bool logits_all; // the llama_eval() call computes all logits, not just the last one
|
||||||
|
|
Loading…
Reference in a new issue