mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-09 15:29:43 +00:00
Multi GPU support, CUDA refactor, CUDA scratch buffer (#1703)
* CUDA multi GPU + scratch ggml_cuda_compute_forward Tensor parallelism ggml_cuda_add ggml_cuda_rms_norm ggml_cuda_silu CUDA scratch buffer --main-gpu CLI option
This commit is contained in:
parent
44f906e853
commit
17366df842
12 changed files with 1221 additions and 544 deletions
|
@ -9,6 +9,7 @@
|
|||
#include <algorithm>
|
||||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
#include <regex>
|
||||
|
||||
#if defined(__APPLE__) && defined(__MACH__)
|
||||
#include <sys/types.h>
|
||||
|
@ -295,6 +296,40 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
|
||||
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
|
||||
#endif
|
||||
} else if (arg == "--main-gpu" || arg == "-mg") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
params.main_gpu = std::stoi(argv[i]);
|
||||
#else
|
||||
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
|
||||
#endif
|
||||
} else if (arg == "--tensor-split" || arg == "-ts") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
std::string arg_next = argv[i];
|
||||
|
||||
// split string by , and /
|
||||
const std::regex regex{R"([,/]+)"};
|
||||
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
|
||||
std::vector<std::string> split_arg{it, {}};
|
||||
GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
|
||||
|
||||
for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
|
||||
if (i < split_arg.size()) {
|
||||
params.tensor_split[i] = std::stof(split_arg[i]);
|
||||
} else {
|
||||
params.tensor_split[i] = 0.0f;
|
||||
}
|
||||
}
|
||||
#else
|
||||
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
|
||||
#endif // GGML_USE_CUBLAS
|
||||
} else if (arg == "--no-mmap") {
|
||||
params.use_mmap = false;
|
||||
} else if (arg == "--mtest") {
|
||||
|
@ -438,6 +473,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
|
||||
fprintf(stderr, " -ngl N, --n-gpu-layers N\n");
|
||||
fprintf(stderr, " number of layers to store in VRAM\n");
|
||||
fprintf(stderr, " -ts SPLIT --tensor-split SPLIT\n");
|
||||
fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
|
||||
fprintf(stderr, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" );
|
||||
#endif
|
||||
fprintf(stderr, " --mtest compute maximum memory usage\n");
|
||||
fprintf(stderr, " --export export the computation graph to 'llama.ggml'\n");
|
||||
|
@ -483,7 +521,10 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
|
|||
auto lparams = llama_context_default_params();
|
||||
|
||||
lparams.n_ctx = params.n_ctx;
|
||||
lparams.n_batch = params.n_batch;
|
||||
lparams.n_gpu_layers = params.n_gpu_layers;
|
||||
lparams.main_gpu = params.main_gpu;
|
||||
memcpy(lparams.tensor_split, params.tensor_split, LLAMA_MAX_DEVICES*sizeof(float));
|
||||
lparams.seed = params.seed;
|
||||
lparams.f16_kv = params.memory_f16;
|
||||
lparams.use_mmap = params.use_mmap;
|
||||
|
|
|
@ -28,6 +28,8 @@ struct gpt_params {
|
|||
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||
int32_t n_gpu_layers = 0; // number of layers to store in VRAM
|
||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
|
||||
|
||||
// sampling parameters
|
||||
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
||||
|
|
|
@ -286,5 +286,7 @@ These options provide extra functionality and customization when running the LLa
|
|||
- `--verbose-prompt`: Print the prompt before generating text.
|
||||
- `--mtest`: Test the model's functionality by running a series of tests to ensure it's working properly.
|
||||
- `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
|
||||
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
|
||||
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
|
||||
- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
|
||||
- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.
|
||||
|
|
|
@ -287,6 +287,8 @@ Test();
|
|||
- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`).
|
||||
- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference.
|
||||
- `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
|
||||
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
|
||||
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
|
||||
- `--embedding`: Enable the embedding mode. **Completion function doesn't work in this mode**.
|
||||
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`;
|
||||
- `--port`: Set the port to listen. Default: `8080`.
|
||||
|
|
|
@ -401,6 +401,10 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params ¶ms)
|
|||
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
|
||||
fprintf(stderr, " -ngl N, --n-gpu-layers N\n");
|
||||
fprintf(stderr, " number of layers to store in VRAM\n");
|
||||
fprintf(stderr, " -ts SPLIT --tensor-split SPLIT\n");
|
||||
fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
|
||||
fprintf(stderr, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
|
||||
fprintf(stderr, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n" );
|
||||
#endif
|
||||
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
||||
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
|
||||
|
@ -502,6 +506,50 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
|
|||
#else
|
||||
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
|
||||
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
|
||||
#endif
|
||||
}
|
||||
else if (arg == "--tensor-split" || arg == "-ts")
|
||||
{
|
||||
if (++i >= argc)
|
||||
{
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
std::string arg_next = argv[i];
|
||||
|
||||
// split string by , and /
|
||||
const std::regex regex{R"([,/]+)"};
|
||||
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
|
||||
std::vector<std::string> split_arg{it, {}};
|
||||
GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
|
||||
|
||||
for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i)
|
||||
{
|
||||
if (i < split_arg.size())
|
||||
{
|
||||
params.tensor_split[i] = std::stof(split_arg[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
params.tensor_split[i] = 0.0f;
|
||||
}
|
||||
}
|
||||
#else
|
||||
fprintf(stderr, "WARNING: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
|
||||
#endif // GGML_USE_CUBLAS
|
||||
}
|
||||
else if (arg == "--main-gpu" || arg == "-mg")
|
||||
{
|
||||
if (++i >= argc)
|
||||
{
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
params.main_gpu = std::stoi(argv[i]);
|
||||
#else
|
||||
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
|
||||
#endif
|
||||
}
|
||||
else
|
||||
|
|
1259
ggml-cuda.cu
1259
ggml-cuda.cu
File diff suppressed because it is too large
Load diff
15
ggml-cuda.h
15
ggml-cuda.h
|
@ -1,10 +1,19 @@
|
|||
#pragma once
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define GGML_CUDA_MAX_DEVICES 16
|
||||
|
||||
struct ggml_tensor_extra_gpu {
|
||||
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
|
||||
};
|
||||
|
||||
void ggml_init_cublas(void);
|
||||
void ggml_cuda_set_tensor_split(const float * tensor_split);
|
||||
|
||||
void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
||||
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
|
||||
|
@ -15,8 +24,12 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
|
|||
void * ggml_cuda_host_malloc(size_t size);
|
||||
void ggml_cuda_host_free(void * ptr);
|
||||
|
||||
void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
|
||||
void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);
|
||||
void ggml_cuda_free_data(struct ggml_tensor * tensor);
|
||||
void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
|
||||
void ggml_cuda_set_main_device(int main_device);
|
||||
void ggml_cuda_set_scratch_size(size_t scratch_size);
|
||||
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -700,7 +700,7 @@ static cl_int ggml_cl_h2d_tensor_2d(cl_command_queue queue, cl_mem dst, size_t o
|
|||
}
|
||||
|
||||
static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(src1->backend == GGML_BACKEND_CL);
|
||||
GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
|
@ -814,7 +814,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
|||
size_t y_size;
|
||||
size_t d_size;
|
||||
cl_mem d_X;
|
||||
if (src0->backend == GGML_BACKEND_CL) {
|
||||
if (src0->backend == GGML_BACKEND_GPU) { // NOLINT
|
||||
d_X = (cl_mem) src0->data;
|
||||
} else {
|
||||
d_X = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &x_size);
|
||||
|
@ -825,7 +825,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
|||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
// copy data to device
|
||||
if (src0->backend != GGML_BACKEND_CL) {
|
||||
if (src0->backend != GGML_BACKEND_GPU) {
|
||||
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
|
||||
}
|
||||
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));
|
||||
|
@ -854,7 +854,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
|
|||
}
|
||||
}
|
||||
|
||||
if (src0->backend != GGML_BACKEND_CL) {
|
||||
if (src0->backend != GGML_BACKEND_GPU) {
|
||||
ggml_cl_pool_free(d_X, x_size);
|
||||
}
|
||||
ggml_cl_pool_free(d_Y, y_size);
|
||||
|
@ -890,7 +890,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
|
|||
size_t y_size;
|
||||
size_t d_size;
|
||||
cl_mem d_X;
|
||||
if (src0->backend == GGML_BACKEND_CL) {
|
||||
if (src0->backend == GGML_BACKEND_GPU) { // NOLINT
|
||||
d_X = (cl_mem) src0->data;
|
||||
} else {
|
||||
d_X = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &x_size);
|
||||
|
@ -904,7 +904,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
|
|||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
// copy src0 to device
|
||||
if (src0->backend != GGML_BACKEND_CL) {
|
||||
if (src0->backend != GGML_BACKEND_GPU) {
|
||||
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
|
||||
}
|
||||
|
||||
|
@ -961,7 +961,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
|
|||
}
|
||||
}
|
||||
|
||||
if (src0->backend != GGML_BACKEND_CL) {
|
||||
if (src0->backend != GGML_BACKEND_GPU) {
|
||||
ggml_cl_pool_free(d_X, x_size);
|
||||
}
|
||||
ggml_cl_pool_free(d_Y, y_size);
|
||||
|
@ -1017,7 +1017,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
|
|||
if (src0->backend == GGML_BACKEND_CPU) {
|
||||
events.emplace_back();
|
||||
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
|
||||
} else if (src0->backend == GGML_BACKEND_CL) {
|
||||
} else if (src0->backend == GGML_BACKEND_GPU) {
|
||||
d_Q = (cl_mem) src0->data;
|
||||
} else {
|
||||
GGML_ASSERT(false);
|
||||
|
@ -1102,7 +1102,7 @@ bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
|
|||
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
|
||||
src1->type == GGML_TYPE_F32 &&
|
||||
dst->type == GGML_TYPE_F32 &&
|
||||
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CL)) {
|
||||
((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -1181,7 +1181,7 @@ void ggml_cl_transform_tensor(ggml_tensor * tensor) {
|
|||
CL_CHECK(clFinish(queue));
|
||||
|
||||
tensor->data = dst;
|
||||
tensor->backend = GGML_BACKEND_CL;
|
||||
tensor->backend = GGML_BACKEND_GPU;
|
||||
}
|
||||
|
||||
void ggml_cl_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
|
||||
|
|
75
ggml.c
75
ggml.c
|
@ -3726,26 +3726,6 @@ struct ggml_context_container {
|
|||
struct ggml_context context;
|
||||
};
|
||||
|
||||
//
|
||||
// compute types
|
||||
//
|
||||
|
||||
enum ggml_task_type {
|
||||
GGML_TASK_INIT = 0,
|
||||
GGML_TASK_COMPUTE,
|
||||
GGML_TASK_FINALIZE,
|
||||
};
|
||||
|
||||
struct ggml_compute_params {
|
||||
enum ggml_task_type type;
|
||||
|
||||
int ith, nth;
|
||||
|
||||
// work buffer for all threads
|
||||
size_t wsize;
|
||||
void * wdata;
|
||||
};
|
||||
|
||||
//
|
||||
// ggml state
|
||||
//
|
||||
|
@ -3821,6 +3801,12 @@ size_t ggml_nbytes(const struct ggml_tensor * tensor) {
|
|||
return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]);
|
||||
}
|
||||
|
||||
size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
return (nrows_split*tensor->ne[0]*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
|
||||
}
|
||||
|
||||
int ggml_blck_size(enum ggml_type type) {
|
||||
return GGML_BLCK_SIZE[type];
|
||||
}
|
||||
|
@ -4248,6 +4234,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
|
|||
/*.perf_time_us =*/ 0,
|
||||
/*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
|
||||
/*.name =*/ { 0 },
|
||||
/*.extra =*/ NULL,
|
||||
/*.pad =*/ { 0 },
|
||||
};
|
||||
|
||||
|
@ -8265,15 +8252,8 @@ static void ggml_compute_forward_mul_f32(
|
|||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (src1->backend == GGML_BACKEND_CUDA) {
|
||||
if (ith == 0) {
|
||||
ggml_cuda_mul(src0, src1, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
if (src1->backend == GGML_BACKEND_CL) {
|
||||
#ifdef GGML_USE_CLBLAST
|
||||
if (src1->backend == GGML_BACKEND_GPU) {
|
||||
if (ith == 0) {
|
||||
ggml_cl_mul(src0, src1, dst);
|
||||
}
|
||||
|
@ -9713,14 +9693,7 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||
// nb01 >= nb00 - src0 is not transposed
|
||||
// compute by src0 rows
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
|
||||
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||
}
|
||||
return;
|
||||
}
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
#if defined(GGML_USE_CLBLAST)
|
||||
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
|
||||
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||
|
@ -9885,14 +9858,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||
// nb01 >= nb00 - src0 is not transposed
|
||||
// compute by src0 rows
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
|
||||
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||
}
|
||||
return;
|
||||
}
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
#if defined(GGML_USE_CLBLAST)
|
||||
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
|
||||
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||
|
@ -10097,14 +10063,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
// nb01 >= nb00 - src0 is not transposed
|
||||
// compute by src0 rows
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
|
||||
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||
}
|
||||
return;
|
||||
}
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
#if defined(GGML_USE_CLBLAST)
|
||||
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
|
||||
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||
|
@ -13057,6 +13016,15 @@ static void ggml_compute_forward_map_binary(
|
|||
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
|
||||
GGML_ASSERT(params);
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
|
||||
if (skip_cpu) {
|
||||
return;
|
||||
}
|
||||
GGML_ASSERT(tensor->src0->backend == GGML_BACKEND_CPU);
|
||||
GGML_ASSERT(tensor->src1 == NULL || tensor->src1->backend == GGML_BACKEND_CPU);
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
switch (tensor->op) {
|
||||
case GGML_OP_DUP:
|
||||
{
|
||||
|
@ -14363,7 +14331,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||
if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
|
||||
node->n_tasks = 1; // TODO: this actually is doing nothing
|
||||
// the threads are still spinning
|
||||
cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
|
||||
}
|
||||
else
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
|
|
30
ggml.h
30
ggml.h
|
@ -256,8 +256,8 @@ extern "C" {
|
|||
|
||||
enum ggml_backend {
|
||||
GGML_BACKEND_CPU = 0,
|
||||
GGML_BACKEND_CUDA = 1,
|
||||
GGML_BACKEND_CL = 2,
|
||||
GGML_BACKEND_GPU = 10,
|
||||
GGML_BACKEND_GPU_SPLIT = 20,
|
||||
};
|
||||
|
||||
// model file types
|
||||
|
@ -387,7 +387,9 @@ extern "C" {
|
|||
|
||||
char name[GGML_MAX_NAME];
|
||||
|
||||
char padding[16];
|
||||
void * extra; // extra things e.g. for ggml-cuda.cu
|
||||
|
||||
char padding[4];
|
||||
};
|
||||
|
||||
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
|
||||
|
@ -425,6 +427,25 @@ extern "C" {
|
|||
bool no_alloc; // don't allocate memory for the tensor data
|
||||
};
|
||||
|
||||
|
||||
// compute types
|
||||
enum ggml_task_type {
|
||||
GGML_TASK_INIT = 0,
|
||||
GGML_TASK_COMPUTE,
|
||||
GGML_TASK_FINALIZE,
|
||||
};
|
||||
|
||||
struct ggml_compute_params {
|
||||
enum ggml_task_type type;
|
||||
|
||||
// ith = thread index, nth = number of threads
|
||||
int ith, nth;
|
||||
|
||||
// work buffer for all threads
|
||||
size_t wsize;
|
||||
void * wdata;
|
||||
};
|
||||
|
||||
// misc
|
||||
|
||||
GGML_API void ggml_time_init(void); // call this once at the beginning of the program
|
||||
|
@ -436,9 +457,10 @@ extern "C" {
|
|||
GGML_API void ggml_print_object (const struct ggml_object * obj);
|
||||
GGML_API void ggml_print_objects(const struct ggml_context * ctx);
|
||||
|
||||
GGML_API int64_t ggml_nelements(const struct ggml_tensor * tensor);
|
||||
GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor);
|
||||
GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor);
|
||||
GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor);
|
||||
GGML_API size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split);
|
||||
|
||||
GGML_API int ggml_blck_size (enum ggml_type type);
|
||||
GGML_API size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
|
||||
|
|
163
llama.cpp
163
llama.cpp
|
@ -59,6 +59,12 @@ static const size_t MB = 1024*1024;
|
|||
// TODO: dynamically determine these sizes
|
||||
// needs modifications in ggml
|
||||
|
||||
typedef void (*offload_func_t)(struct ggml_tensor * tensor);
|
||||
|
||||
void llama_nop(struct ggml_tensor * tensor) { // don't offload by default
|
||||
(void) tensor;
|
||||
}
|
||||
|
||||
static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0()
|
||||
{
|
||||
static std::map<e_model, size_t> k_sizes = {
|
||||
|
@ -173,6 +179,7 @@ struct llama_model {
|
|||
struct ggml_tensor * output;
|
||||
|
||||
std::vector<llama_layer> layers;
|
||||
int n_gpu_layers;
|
||||
|
||||
// context
|
||||
struct ggml_context * ctx = NULL;
|
||||
|
@ -198,6 +205,12 @@ struct llama_model {
|
|||
if (ctx) {
|
||||
ggml_free(ctx);
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
for (size_t i = 0; i < tensors_by_name.size(); ++i) {
|
||||
ggml_cuda_free_data(tensors_by_name[i].second);
|
||||
}
|
||||
#endif // GGML_USE_CUBLAS
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -698,6 +711,7 @@ struct llama_model_loader {
|
|||
}
|
||||
ggml_set_name(tensor, lt.name.c_str());
|
||||
LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor
|
||||
|
||||
tensor->backend = backend;
|
||||
lt.ggml_tensor = tensor;
|
||||
num_ggml_tensors_created++;
|
||||
|
@ -850,7 +864,10 @@ static bool kv_cache_init(
|
|||
struct llama_context_params llama_context_default_params() {
|
||||
struct llama_context_params result = {
|
||||
/*.n_ctx =*/ 512,
|
||||
/*.n_batch =*/ 512,
|
||||
/*.gpu_layers =*/ 0,
|
||||
/*.main_gpu =*/ 0,
|
||||
/*.tensor_split =*/ {0},
|
||||
/*.seed =*/ -1,
|
||||
/*.f16_kv =*/ true,
|
||||
/*.logits_all =*/ false,
|
||||
|
@ -944,7 +961,10 @@ static void llama_model_load_internal(
|
|||
const std::string & fname,
|
||||
llama_context & lctx,
|
||||
int n_ctx,
|
||||
int n_batch,
|
||||
int n_gpu_layers,
|
||||
int main_gpu,
|
||||
const float * tensor_split,
|
||||
ggml_type memory_type,
|
||||
bool use_mmap,
|
||||
bool use_mlock,
|
||||
|
@ -959,6 +979,7 @@ static void llama_model_load_internal(
|
|||
lctx.vocab = std::move(ml->file_loaders.at(0)->vocab);
|
||||
auto & model = lctx.model;
|
||||
model.hparams = ml->file_loaders.at(0)->hparams;
|
||||
model.n_gpu_layers = n_gpu_layers;
|
||||
llama_file_version file_version = ml->file_loaders.at(0)->file_version;
|
||||
auto & hparams = model.hparams;
|
||||
|
||||
|
@ -1039,17 +1060,22 @@ static void llama_model_load_internal(
|
|||
}
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CUDA
|
||||
fprintf(stderr, "%s: using CUDA for GPU acceleration\n", __func__);
|
||||
ggml_cuda_set_main_device(main_gpu);
|
||||
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
|
||||
#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CL
|
||||
fprintf(stderr, "%s: using OpenCL for GPU acceleration\n", __func__);
|
||||
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
|
||||
#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU
|
||||
#else
|
||||
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CPU
|
||||
#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_CPU
|
||||
#endif
|
||||
|
||||
// prepare memory for the weights
|
||||
size_t vram_total = 0;
|
||||
size_t vram_weights = 0;
|
||||
size_t vram_scratch = 0;
|
||||
{
|
||||
const uint32_t n_embd = hparams.n_embd;
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
|
@ -1064,7 +1090,7 @@ static void llama_model_load_internal(
|
|||
{
|
||||
ggml_backend backend_output;
|
||||
if (n_gpu_layers > int(n_layer)) { // NOLINT
|
||||
backend_output = LLAMA_BACKEND_OFFLOAD;
|
||||
backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
|
||||
} else {
|
||||
backend_output = GGML_BACKEND_CPU;
|
||||
}
|
||||
|
@ -1076,7 +1102,8 @@ static void llama_model_load_internal(
|
|||
|
||||
model.layers.resize(n_layer);
|
||||
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||
const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
|
||||
const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
|
||||
const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
|
@ -1084,19 +1111,19 @@ static void llama_model_load_internal(
|
|||
|
||||
layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend);
|
||||
|
||||
layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend);
|
||||
layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend);
|
||||
layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend);
|
||||
layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend);
|
||||
layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend_split);
|
||||
layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend_split);
|
||||
layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend_split);
|
||||
layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split);
|
||||
|
||||
layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
|
||||
|
||||
layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend);
|
||||
layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend);
|
||||
layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend);
|
||||
layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split);
|
||||
layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split);
|
||||
layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split);
|
||||
|
||||
if (backend == LLAMA_BACKEND_OFFLOAD) {
|
||||
vram_total +=
|
||||
if (backend == GGML_BACKEND_GPU) {
|
||||
vram_weights +=
|
||||
ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) +
|
||||
ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.attention_norm) +
|
||||
ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
|
||||
|
@ -1113,7 +1140,7 @@ static void llama_model_load_internal(
|
|||
// this is the total memory required to run the inference
|
||||
const size_t mem_required =
|
||||
ctx_size +
|
||||
mmapped_size - vram_total + // weights in VRAM not in memory
|
||||
mmapped_size - vram_weights + // weights in VRAM not in memory
|
||||
MEM_REQ_SCRATCH0().at(model.type) +
|
||||
MEM_REQ_SCRATCH1().at(model.type) +
|
||||
MEM_REQ_EVAL().at (model.type);
|
||||
|
@ -1127,12 +1154,21 @@ static void llama_model_load_internal(
|
|||
|
||||
const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer));
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
vram_scratch = n_batch * MB;
|
||||
ggml_cuda_set_scratch_size(vram_scratch);
|
||||
if (n_gpu_layers > 0) {
|
||||
fprintf(stderr, "%s: allocating batch_size x 1 MB = %ld MB VRAM for the scratch buffer\n",
|
||||
__func__, vram_scratch / MB);
|
||||
}
|
||||
#endif // GGML_USE_CUBLAS
|
||||
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
|
||||
fprintf(stderr, "%s: offloading %d layers to GPU\n", __func__, n_gpu);
|
||||
if (n_gpu_layers > (int) hparams.n_layer) {
|
||||
fprintf(stderr, "%s: offloading output layer to GPU\n", __func__);
|
||||
}
|
||||
fprintf(stderr, "%s: total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
|
||||
fprintf(stderr, "%s: total VRAM used: %zu MB\n",
|
||||
__func__, (vram_weights + vram_scratch + MB - 1) / MB); // round up
|
||||
#else
|
||||
(void) n_gpu_layers;
|
||||
#endif
|
||||
|
@ -1147,6 +1183,8 @@ static void llama_model_load_internal(
|
|||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
{
|
||||
ggml_cuda_set_tensor_split(tensor_split);
|
||||
|
||||
size_t done_size = 0;
|
||||
size_t data_size = 0;
|
||||
for (llama_load_tensor & lt : ml->tensors_map.tensors) {
|
||||
|
@ -1156,7 +1194,8 @@ static void llama_model_load_internal(
|
|||
}
|
||||
}
|
||||
for (llama_load_tensor & lt : ml->tensors_map.tensors) {
|
||||
if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) {
|
||||
ggml_backend backend = lt.ggml_tensor->backend;
|
||||
if (backend != GGML_BACKEND_GPU && backend != GGML_BACKEND_GPU_SPLIT) {
|
||||
continue;
|
||||
}
|
||||
if (progress_callback) {
|
||||
|
@ -1177,7 +1216,7 @@ static void llama_model_load_internal(
|
|||
}
|
||||
}
|
||||
for (llama_load_tensor & lt : ml->tensors_map.tensors) {
|
||||
if (lt.ggml_tensor->backend != GGML_BACKEND_CL) {
|
||||
if (lt.ggml_tensor->backend != GGML_BACKEND_GPU) {
|
||||
continue;
|
||||
}
|
||||
if (progress_callback) {
|
||||
|
@ -1187,6 +1226,9 @@ static void llama_model_load_internal(
|
|||
done_size += lt.size;
|
||||
}
|
||||
}
|
||||
#else
|
||||
(void) n_batch;
|
||||
(void) tensor_split;
|
||||
#endif
|
||||
|
||||
if (progress_callback) {
|
||||
|
@ -1204,7 +1246,10 @@ static bool llama_model_load(
|
|||
const std::string & fname,
|
||||
llama_context & lctx,
|
||||
int n_ctx,
|
||||
int n_batch,
|
||||
int n_gpu_layers,
|
||||
int main_gpu,
|
||||
float * tensor_split,
|
||||
ggml_type memory_type,
|
||||
bool use_mmap,
|
||||
bool use_mlock,
|
||||
|
@ -1212,8 +1257,8 @@ static bool llama_model_load(
|
|||
llama_progress_callback progress_callback,
|
||||
void *progress_callback_user_data) {
|
||||
try {
|
||||
llama_model_load_internal(fname, lctx, n_ctx, n_gpu_layers, memory_type, use_mmap, use_mlock,
|
||||
vocab_only, progress_callback, progress_callback_user_data);
|
||||
llama_model_load_internal(fname, lctx, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, memory_type,
|
||||
use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
|
||||
return true;
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "error loading model: %s\n", err.what());
|
||||
|
@ -1260,6 +1305,7 @@ static bool llama_eval_internal(
|
|||
const int n_head = hparams.n_head;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
const int n_rot = hparams.n_embd/hparams.n_head;
|
||||
const int n_gpu_layers = model.n_gpu_layers;
|
||||
|
||||
auto & mem_per_token = lctx.mem_per_token;
|
||||
auto & buf_compute = lctx.buf_compute;
|
||||
|
@ -1284,7 +1330,17 @@ static bool llama_eval_internal(
|
|||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
|
||||
|
||||
const int i_gpu_start = n_layer - n_gpu_layers;
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
offload_func_t offload_func = llama_nop;
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (il >= i_gpu_start) {
|
||||
offload_func = ggml_cuda_assign_buffers; // sets the output backend to GPU
|
||||
}
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
lctx.use_buf(ctx0, 0);
|
||||
|
@ -1292,20 +1348,32 @@ static bool llama_eval_internal(
|
|||
// norm
|
||||
{
|
||||
cur = ggml_rms_norm(ctx0, inpL);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "rms_norm_0");
|
||||
|
||||
// cur = cur*attention_norm(broadcasted)
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "attention_norm_0");
|
||||
}
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
// offload_func(tmpq);
|
||||
ggml_set_name(tmpq, "tmpq");
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||
ggml_set_name(Qcur, "Qcur");
|
||||
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
|
||||
// offload_func(tmpk);
|
||||
ggml_set_name(tmpk, "tmpk");
|
||||
|
||||
struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||
ggml_set_name(Kcur, "Kcur");
|
||||
|
||||
struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||
ggml_set_name(Qcur, "Qcur");
|
||||
|
||||
// store key and value to memory
|
||||
{
|
||||
// compute the transposed [N, n_embd] V matrix
|
||||
|
@ -1313,9 +1381,11 @@ static bool llama_eval_internal(
|
|||
ggml_set_name(Vcur, "Vcur");
|
||||
|
||||
struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
|
||||
ggml_set_name(k, "k");
|
||||
struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
|
||||
( n_ctx)*ggml_element_size(kv_self.v),
|
||||
(il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
|
||||
ggml_set_name(v, "v");
|
||||
|
||||
// important: storing RoPE-ed version of K in the KV cache!
|
||||
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
|
||||
|
@ -1390,63 +1460,104 @@ static bool llama_eval_internal(
|
|||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].wo,
|
||||
cur);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "result_wo");
|
||||
}
|
||||
|
||||
lctx.use_buf(ctx0, 1);
|
||||
//ggml_cuda_set_scratch(1);
|
||||
|
||||
struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
|
||||
offload_func(inpFF);
|
||||
ggml_set_name(inpFF, "inpFF");
|
||||
|
||||
// feed-forward network
|
||||
{
|
||||
// norm
|
||||
{
|
||||
cur = ggml_rms_norm(ctx0, inpFF);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "rms_norm_1");
|
||||
|
||||
// cur = cur*ffn_norm(broadcasted)
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "ffn_norm");
|
||||
}
|
||||
|
||||
struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
|
||||
model.layers[il].w3,
|
||||
cur);
|
||||
offload_func(tmp);
|
||||
ggml_set_name(tmp, "result_w3");
|
||||
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].w1,
|
||||
cur);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "result_w2");
|
||||
|
||||
// SILU activation
|
||||
cur = ggml_silu(ctx0, cur);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "silu");
|
||||
|
||||
cur = ggml_mul(ctx0, cur, tmp);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "silu_x_result_w3");
|
||||
|
||||
cur = ggml_mul_mat(ctx0,
|
||||
model.layers[il].w2,
|
||||
cur);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "result_w2");
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, inpFF);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "inpFF_+_result_w2");
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
|
||||
}
|
||||
|
||||
lctx.use_buf(ctx0, 0);
|
||||
//ggml_cuda_set_scratch(0);
|
||||
|
||||
// used at the end to optionally extract the embeddings
|
||||
struct ggml_tensor * embeddings = NULL;
|
||||
|
||||
offload_func_t offload_func = llama_nop;
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (n_gpu_layers > n_layer) {
|
||||
offload_func = ggml_cuda_assign_buffers; // sets the output backend to GPU
|
||||
}
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
// norm
|
||||
{
|
||||
cur = ggml_rms_norm(ctx0, inpL);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "rms_norm_inpL");
|
||||
|
||||
cur = ggml_rms_norm(ctx0, cur);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "rms_norm_after");
|
||||
|
||||
// cur = cur*norm(broadcasted)
|
||||
cur = ggml_mul(ctx0, cur, model.norm);
|
||||
offload_func(cur);
|
||||
ggml_set_name(cur, "result_norm");
|
||||
|
||||
embeddings = cur;
|
||||
}
|
||||
|
||||
|
||||
// lm_head
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
ggml_set_name(cur, "result_output");
|
||||
|
||||
lctx.use_buf(ctx0, -1);
|
||||
|
||||
|
@ -2366,9 +2477,9 @@ struct llama_context * llama_init_from_file(
|
|||
|
||||
ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type,
|
||||
params.use_mmap, params.use_mlock, params.vocab_only,
|
||||
params.progress_callback, params.progress_callback_user_data)) {
|
||||
if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_batch, params.n_gpu_layers,
|
||||
params.main_gpu, params.tensor_split, memory_type, params.use_mmap, params.use_mlock,
|
||||
params.vocab_only, params.progress_callback, params.progress_callback_user_data)) {
|
||||
fprintf(stderr, "%s: failed to load model\n", __func__);
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
|
|
10
llama.h
10
llama.h
|
@ -1,6 +1,13 @@
|
|||
#ifndef LLAMA_H
|
||||
#define LLAMA_H
|
||||
|
||||
#include "ggml.h"
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
#include "ggml-cuda.h"
|
||||
#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
|
||||
#else
|
||||
#define LLAMA_MAX_DEVICES 1
|
||||
#endif // GGML_USE_CUBLAS
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
@ -66,7 +73,10 @@ extern "C" {
|
|||
|
||||
struct llama_context_params {
|
||||
int n_ctx; // text context
|
||||
int n_batch; // prompt processing batch size
|
||||
int n_gpu_layers; // number of layers to store in VRAM
|
||||
int main_gpu; // the GPU that is used for scratch and small tensors
|
||||
float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
|
||||
int seed; // RNG seed, -1 for random
|
||||
|
||||
bool f16_kv; // use fp16 for KV cache
|
||||
|
|
Loading…
Reference in a new issue