From 5656d10599bd756dc0f17284e418e704200b43f3 Mon Sep 17 00:00:00 2001 From: Evan Miller Date: Mon, 10 Jul 2023 11:49:56 -0400 Subject: [PATCH] mpi : add support for distributed inference via MPI (#2099) * MPI support, first cut * fix warnings, update README * fixes * wrap includes * PR comments * Update CMakeLists.txt * Add GH workflow, fix test * Add info to README * mpi : trying to move more MPI stuff into ggml-mpi (WIP) (#2099) * mpi : add names for layer inputs + prep ggml_mpi_graph_compute() * mpi : move all MPI logic into ggml-mpi Not tested yet * mpi : various fixes - communication now works but results are wrong * mpi : fix output tensor after MPI compute (still not working) * mpi : fix inference * mpi : minor * Add OpenMPI to GH action * [mpi] continue-on-error: true * mpi : fix after master merge * [mpi] Link MPI C++ libraries to fix OpenMPI * tests : fix new llama_backend API * [mpi] use MPI_INT32_T * mpi : factor out recv / send in functions and reuse * mpi : extend API to allow usage with outer backends (e.g. Metal) --------- Co-authored-by: Georgi Gerganov --- .github/workflows/build.yml | 34 ++++ .gitignore | 1 + CMakeLists.txt | 24 +++ Makefile | 9 ++ README.md | 39 +++++ examples/embd-input/embd-input-lib.cpp | 2 +- examples/embedding/embedding.cpp | 4 +- examples/main/main.cpp | 4 +- examples/perplexity/perplexity.cpp | 4 +- examples/quantize/quantize.cpp | 4 +- examples/server/server.cpp | 4 +- examples/simple/simple.cpp | 4 +- ggml-metal.m | 1 + ggml-mpi.c | 216 +++++++++++++++++++++++++ ggml-mpi.h | 39 +++++ llama.cpp | 98 +++++++---- llama.h | 4 +- tests/test-tokenizer-0.cpp | 4 + 18 files changed, 460 insertions(+), 35 deletions(-) create mode 100644 ggml-mpi.c create mode 100644 ggml-mpi.h diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f6a2dd6..b6e21b4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -104,6 +104,40 @@ jobs: cd build ctest --verbose --timeout 900 + ubuntu-latest-cmake-mpi: + runs-on: ubuntu-latest + + continue-on-error: true + + strategy: + matrix: + mpi_library: [mpich, libopenmpi-dev] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v1 + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential ${{ matrix.mpi_library }} + + - name: Build + id: cmake_build + run: | + mkdir build + cd build + cmake -DLLAMA_MPI=ON .. + cmake --build . 
--config Release + + - name: Test + id: cmake_test + run: | + cd build + ctest --verbose + macOS-latest-make: runs-on: macos-latest diff --git a/.gitignore b/.gitignore index 4fccec3..faec869 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ build-static/ build-cublas/ build-opencl/ build-metal/ +build-mpi/ build-no-accel/ build-sanitize-addr/ build-sanitize-thread/ diff --git a/CMakeLists.txt b/CMakeLists.txt index eed7b1b..cf6cd34 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,6 +75,7 @@ option(LLAMA_CUDA_DMMV_F16 "llama: use 16 bit floats for dmmv set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_METAL "llama: use Metal" OFF) +option(LLAMA_MPI "llama: use MPI" OFF) option(LLAMA_K_QUANTS "llama: use k-quants" ON) option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF) @@ -308,6 +309,28 @@ if (LLAMA_METAL) ) endif() +if (LLAMA_MPI) + cmake_minimum_required(VERSION 3.10) + find_package(MPI) + if (MPI_C_FOUND) + message(STATUS "MPI found") + set(GGML_SOURCES_MPI ggml-mpi.c ggml-mpi.h) + add_compile_definitions(GGML_USE_MPI) + add_compile_definitions(${MPI_C_COMPILE_DEFINITIONS}) + set(cxx_flags ${cxx_flags} -Wno-cast-qual) + set(c_flags ${c_flags} -Wno-cast-qual) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_C_LIBRARIES}) + set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${MPI_C_INCLUDE_DIRS}) + # Even if you're only using the C header, C++ programs may bring in MPI + # C++ functions, so more linkage is needed + if (MPI_CXX_FOUND) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${MPI_CXX_LIBRARIES}) + endif() + else() + message(WARNING "MPI not found") + endif() +endif() + if (LLAMA_CLBLAST) find_package(CLBlast) if (CLBlast_FOUND) @@ -476,6 +499,7 @@ add_library(ggml OBJECT ${GGML_SOURCES_CUDA} ${GGML_SOURCES_OPENCL} ${GGML_SOURCES_METAL} + ${GGML_SOURCES_MPI} ${GGML_SOURCES_EXTRA} ) diff --git a/Makefile b/Makefile index 6068cbe..f887ed6 100644 --- a/Makefile +++ b/Makefile @@ -147,6 +147,15 @@ ifndef LLAMA_NO_ACCELERATE endif endif # LLAMA_NO_ACCELERATE +ifdef LLAMA_MPI + CFLAGS += -DGGML_USE_MPI -Wno-cast-qual + CXXFLAGS += -DGGML_USE_MPI -Wno-cast-qual + OBJS += ggml-mpi.o + +ggml-mpi.o: ggml-mpi.c ggml-mpi.h + $(CC) $(CFLAGS) -c $< -o $@ +endif # LLAMA_MPI + ifdef LLAMA_OPENBLAS CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas -I/usr/include/openblas LDFLAGS += -lopenblas diff --git a/README.md b/README.md index daa71c2..63457b6 100644 --- a/README.md +++ b/README.md @@ -268,6 +268,45 @@ Any value larger than 0 will offload the computation to the GPU. For example: ./main -m ./models/7B/ggml-model-q4_0.bin -n 128 -ngl 1 ``` +### MPI Build + +MPI lets you distribute the computation over a cluster of machines. Because of the serial nature of LLM prediction, this won't yield any end-to-end speed-ups, but it will let you run larger models than would otherwise fit into RAM on a single machine. + +First you will need MPI libraries installed on your system. The two most popular (only?) options are [MPICH](https://www.mpich.org) and [OpenMPI](https://www.open-mpi.org). Either can be installed with a package manager (`apt`, Homebrew, MacPorts, etc). 
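For illustration, installing either library with a package manager might look like the following (an indicative sketch, not part of the patch; `mpich` and `libopenmpi-dev` are the packages exercised by the CI workflow above, and `open-mpi` is the Homebrew formula name):

```bash
# Debian/Ubuntu - OpenMPI development files and runtime
sudo apt install libopenmpi-dev openmpi-bin

# Debian/Ubuntu - MPICH as an alternative
sudo apt install mpich

# macOS - Homebrew
brew install open-mpi
```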
+ +Next you will need to build the project with `LLAMA_MPI` set to true on all machines; if you're building with `make`, you will also need to specify an MPI-capable compiler (when building with CMake, this is configured automatically): + +- Using `make`: + + ```bash + make CC=mpicc CXX=mpicxx LLAMA_MPI=1 + ``` + +- Using `CMake`: + + ```bash + cmake -S . -B build -DLLAMA_MPI=ON + ``` + +Once the programs are built, download/convert the weights on all of the machines in your cluster. The paths to the weights and programs should be identical on all machines. + +Next, ensure password-less SSH access to each machine from the primary host, and create a `hostfile` with a list of the hostnames and their relative "weights" (slots). If you want to use localhost for computation, use its local subnet IP address rather than the loopback address or "localhost". + +Here is an example hostfile: + +``` +192.168.0.1:2 +malvolio.local:1 +``` + +The above will distribute the computation across 2 processes on the first host and 1 process on the second host. Each process will use roughly an equal amount of RAM. Try to keep these numbers small, as inter-process (intra-host) communication is expensive. + +Finally, you're ready to run a computation using `mpirun`: + +```bash +mpirun -hostfile hostfile -n 3 ./main -m ./models/7B/ggml-model-q4_0.bin -n 128 +``` + ### BLAS Build Building the program with BLAS support may lead to some performance improvements in prompt processing using batch sizes higher than 32 (the default is 512). BLAS doesn't affect the normal generation performance. There are currently three different implementations of it: diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index 5fa4942..2656382 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -34,7 +34,7 @@ struct MyModel* create_mymodel(int argc, char ** argv) { } fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 03e801c..5192d6d 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -35,7 +35,7 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -93,5 +93,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + llama_backend_free(); + return 0; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 0f6391a..07d8fc6 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -105,7 +105,7 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -671,5 +671,7 @@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + llama_backend_free(); + return 0; } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index fd4b03c..7e120ff 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -147,7 +147,7 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -172,5 +172,7 
@@ int main(int argc, char ** argv) { llama_free(ctx); llama_free_model(model); + llama_backend_free(); + return 0; } diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 1eb0f75..797d2f0 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -180,7 +180,7 @@ int main(int argc, char ** argv) { usage(argv[0]); } - llama_init_backend(false); + llama_backend_init(false); // parse command line arguments const std::string fname_inp = argv[arg_idx]; @@ -257,5 +257,7 @@ int main(int argc, char ** argv) { printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0); } + llama_backend_free(); + return 0; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2cbfc00..296c5d6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1079,7 +1079,7 @@ int main(int argc, char **argv) params.model_alias = params.model; } - llama_init_backend(params.numa); + llama_backend_init(params.numa); LOG_INFO("build info", {{"build", BUILD_NUMBER}, {"commit", BUILD_COMMIT}}); @@ -1309,5 +1309,7 @@ int main(int argc, char **argv) return 1; } + llama_backend_free(); + return 0; } diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 2d913ce..aa2c435 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -66,7 +66,7 @@ int main(int argc, char ** argv) // Init LLM : //--------------------------------- - llama_init_backend(params.numa); + llama_backend_init(params.numa); llama_model * model; llama_context * ctx; @@ -173,6 +173,8 @@ int main(int argc, char ** argv) llama_free( ctx ); llama_free_model( model ); + llama_backend_free(); + return 0; } diff --git a/ggml-metal.m b/ggml-metal.m index 3f15f79..6473644 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -450,6 +450,7 @@ void ggml_metal_graph_compute( //} switch (dst->op) { + case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_TRANSPOSE: diff --git a/ggml-mpi.c b/ggml-mpi.c new file mode 100644 index 0000000..872e808 --- /dev/null +++ b/ggml-mpi.c @@ -0,0 +1,216 @@ +#include "ggml-mpi.h" + +#include "ggml.h" + +#include <mpi.h> + +#include <stdio.h> +#include <stdlib.h> + +#define MIN(a, b) ((a) < (b) ?
(a) : (b)) + +#define UNUSED GGML_UNUSED + +struct ggml_mpi_context { + int rank; + int size; +}; + +void ggml_mpi_backend_init(void) { + MPI_Init(NULL, NULL); +} + +void ggml_mpi_backend_free(void) { + MPI_Finalize(); +} + +struct ggml_mpi_context * ggml_mpi_init(void) { + struct ggml_mpi_context * ctx = calloc(1, sizeof(struct ggml_mpi_context)); + + MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank); + MPI_Comm_size(MPI_COMM_WORLD, &ctx->size); + + return ctx; +} + +void ggml_mpi_free(struct ggml_mpi_context * ctx) { + free(ctx); +} + +int ggml_mpi_rank(struct ggml_mpi_context * ctx) { + return ctx->rank; +} + +void ggml_mpi_eval_init( + struct ggml_mpi_context * ctx_mpi, + int * n_tokens, + int * n_past, + int * n_threads) { + UNUSED(ctx_mpi); + + // synchronize the worker node parameters with the root node + MPI_Barrier(MPI_COMM_WORLD); + + MPI_Bcast(n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Bcast(n_past, 1, MPI_INT, 0, MPI_COMM_WORLD); + MPI_Bcast(n_threads, 1, MPI_INT, 0, MPI_COMM_WORLD); +} + +static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) { + struct ggml_tensor * t = ggml_graph_get_tensor(gf, name); + if (t == NULL) { + fprintf(stderr, "%s: tensor %s not found\n", __func__, name); + return -1; + } + + for (int i = 0; i < gf->n_nodes; i++) { + if (gf->nodes[i] == t) { + return i; + } + } + + fprintf(stderr, "%s: tensor %s not found in graph (should not happen)\n", __func__, name); + return -1; +} + +static void ggml_mpi_tensor_send(struct ggml_tensor * t, int mpi_rank_dst) { + MPI_Datatype mpi_type; + + switch (t->type) { + case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break; + case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break; + default: GGML_ASSERT(false && "not implemented"); + } + + const int retval = MPI_Send(t->data, ggml_nelements(t), mpi_type, mpi_rank_dst, 0, MPI_COMM_WORLD); + GGML_ASSERT(retval == MPI_SUCCESS); +} + +static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src) { + MPI_Datatype mpi_type; + + switch (t->type) { + case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break; + case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break; + default: GGML_ASSERT(false && "not implemented"); + } + + MPI_Status status; UNUSED(status); + + const int retval = MPI_Recv(t->data, ggml_nelements(t), mpi_type, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); + GGML_ASSERT(retval == MPI_SUCCESS); +} + +// TODO: there are many improvements that can be done to this implementation +void ggml_mpi_graph_compute_pre( + struct ggml_mpi_context * ctx_mpi, + struct ggml_cgraph * gf, + int n_layers) { + const int mpi_rank = ctx_mpi->rank; + const int mpi_size = ctx_mpi->size; + + struct ggml_tensor * inp_tokens = ggml_graph_get_tensor(gf, "inp_tokens"); + if (inp_tokens == NULL) { + fprintf(stderr, "%s: tensor 'inp_tokens' not found\n", __func__); + return; + } + + struct ggml_tensor * inp0 = ggml_graph_get_tensor(gf, "layer_inp_0"); + if (inp0 == NULL) { + fprintf(stderr, "%s: tensor 'inp0' not found\n", __func__); + return; + } + + GGML_ASSERT(inp0 == gf->nodes[0]); + + // distribute the compute graph into slices across the MPI nodes + // + // the main node (0) processes the last layers + the remainder of the compute graph + // and is responsible to pass the input tokens to the first node (1) + // + // node 1: [( 0) * n_per_node, ( 1) * n_per_node) + // node 2: [( 1) * n_per_node, ( 2) * n_per_node) + // ... 
+ // node n-1: [(n-2) * n_per_node, (n-1) * n_per_node) + // node 0: [(n-1) * n_per_node, n_nodes) + // + if (mpi_rank > 0) { + if (mpi_rank == 1) { + // the first node (1) receives the input tokens from the main node (0) + ggml_mpi_tensor_recv(inp_tokens, 0); + } else { + // recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph) + ggml_mpi_tensor_recv(inp0, mpi_rank - 1); + } + } else if (mpi_size > 1) { + // node 0 sends the input tokens to node 1 + ggml_mpi_tensor_send(inp_tokens, 1); + + // recv the output data from the last node + ggml_mpi_tensor_recv(inp0, mpi_size - 1); + } + + { + const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size; + + const int mpi_idx = mpi_rank > 0 ? mpi_rank - 1 : mpi_size - 1; + + const int il0 = (mpi_idx + 0) * n_per_node; + const int il1 = MIN(n_layers, (mpi_idx + 1) * n_per_node); + + char name_l0[GGML_MAX_NAME]; + char name_l1[GGML_MAX_NAME]; + + snprintf(name_l0, sizeof(name_l0), "layer_inp_%d", il0); + snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1); + + const int idx_l0 = ggml_graph_get_node_idx(gf, name_l0); + const int idx_l1 = mpi_rank > 0 ? ggml_graph_get_node_idx(gf, name_l1) + 1 : gf->n_nodes; + + if (idx_l0 < 0 || idx_l1 < 0) { + fprintf(stderr, "%s: layer input nodes not found\n", __func__); + return; + } + + // attach the input data to all nodes that need it + // TODO: not great - should be able to do this without modifying the compute graph (see next TODO below) + for (int i = idx_l0; i < idx_l1; i++) { + if (gf->nodes[i]->src0 == gf->nodes[idx_l0]) { + gf->nodes[i]->src0 = inp0; + } + if (gf->nodes[i]->src1 == gf->nodes[idx_l0]) { + gf->nodes[i]->src1 = inp0; + } + } + + // TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph + for (int i = 1; i < idx_l1 - idx_l0; i++) { + gf->nodes[i] = gf->nodes[idx_l0 + i]; + gf->grads[i] = gf->grads[idx_l0 + i]; + } + + // the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node + if (mpi_idx != 0) { + gf->nodes[0]->op = GGML_OP_NONE; + } + + gf->n_nodes = idx_l1 - idx_l0; + + //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1); + } +} + +void ggml_mpi_graph_compute_post( + struct ggml_mpi_context * ctx_mpi, + struct ggml_cgraph * gf, + int n_layers) { + UNUSED(n_layers); + + const int mpi_rank = ctx_mpi->rank; + const int mpi_size = ctx_mpi->size; + + // send the output data to the next node + if (mpi_rank > 0) { + ggml_mpi_tensor_send(gf->nodes[gf->n_nodes - 1], (mpi_rank + 1) % mpi_size); + } +} diff --git a/ggml-mpi.h b/ggml-mpi.h new file mode 100644 index 0000000..eda119d --- /dev/null +++ b/ggml-mpi.h @@ -0,0 +1,39 @@ +#pragma once + +struct ggml_context; +struct ggml_tensor; +struct ggml_cgraph; + +#ifdef __cplusplus +extern "C" { +#endif + +struct ggml_mpi_context; + +void ggml_mpi_backend_init(void); +void ggml_mpi_backend_free(void); + +struct ggml_mpi_context * ggml_mpi_init(void); +void ggml_mpi_free(struct ggml_mpi_context * ctx); + +int ggml_mpi_rank(struct ggml_mpi_context * ctx); + +void ggml_mpi_eval_init( + struct ggml_mpi_context * ctx_mpi, + int * n_tokens, + int * n_past, + int * n_threads); + +void ggml_mpi_graph_compute_pre( + struct ggml_mpi_context * ctx_mpi, + struct ggml_cgraph * gf, + int n_layers); + +void ggml_mpi_graph_compute_post( + struct ggml_mpi_context * ctx_mpi, + struct ggml_cgraph * gf, + int n_layers); + +#ifdef __cplusplus +} +#endif diff --git 
a/llama.cpp b/llama.cpp index a491f1c..ad7283f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -19,6 +19,9 @@ #ifdef GGML_USE_METAL #include "ggml-metal.h" #endif +#ifdef GGML_USE_MPI +#include "ggml-mpi.h" +#endif #ifdef GGML_USE_K_QUANTS #ifndef QK_K #ifdef GGML_QKK_64 @@ -352,6 +355,10 @@ struct llama_context { ggml_metal_context * ctx_metal = NULL; #endif +#ifdef GGML_USE_MPI + ggml_mpi_context * ctx_mpi = NULL; +#endif + int buf_last = 0; size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 }; @@ -870,7 +877,7 @@ bool llama_mlock_supported() { return llama_mlock::SUPPORTED; } -void llama_init_backend(bool numa) { +void llama_backend_init(bool numa) { ggml_time_init(); // needed to initialize f16 tables @@ -883,6 +890,16 @@ void llama_init_backend(bool numa) { if (numa) { ggml_numa_init(); } + +#ifdef GGML_USE_MPI + ggml_mpi_backend_init(); +#endif +} + +void llama_backend_free() { +#ifdef GGML_USE_MPI + ggml_mpi_backend_free(); +#endif } int64_t llama_time_us() { @@ -1284,13 +1301,17 @@ static bool llama_eval_internal( llama_context & lctx, const llama_token * tokens, const float * embd, - const int n_tokens, - const int n_past, + int n_tokens, + int n_past, int n_threads, const char * cgraph_fname) { LLAMA_ASSERT((!tokens && embd) || (tokens && !embd)); +#ifdef GGML_USE_MPI + ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); +#endif + const int64_t t_start_us = ggml_time_us(); const int N = n_tokens; @@ -1331,11 +1352,16 @@ static bool llama_eval_internal( struct ggml_tensor * inpL; if (tokens) { - struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); - ggml_set_name(embd, "embd"); - memcpy(embd->data, tokens, N*ggml_element_size(embd)); - inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens)); + ggml_set_name(inp_tokens, "inp_tokens"); + + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N); memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL)); } @@ -1353,18 +1379,20 @@ static bool llama_eval_internal( offload_func_t offload_func_v = llama_nop; #ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer) { - offload_func_nr = ggml_cuda_assign_buffers; - } - if (n_gpu_layers > n_layer + 1) { - offload_func_v = ggml_cuda_assign_buffers; - } - if (n_gpu_layers > n_layer + 2) { - offload_func_kq = ggml_cuda_assign_buffers; - } + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers; + } #endif // GGML_USE_CUBLAS for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); + offload_func_t offload_func = llama_nop; #ifdef GGML_USE_CUBLAS @@ -1571,7 +1599,6 @@ static bool llama_eval_internal( // input for next layer inpL = cur; - } lctx.use_buf(ctx0, 0); @@ -1579,7 +1606,6 @@ static bool llama_eval_internal( // used at the end to optionally extract the embeddings struct ggml_tensor * embeddings = NULL; - // norm { cur = ggml_rms_norm(ctx0, inpL); @@ -1594,7 +1620,6 @@ static bool llama_eval_internal( embeddings = cur; } - // lm_head cur = ggml_mul_mat(ctx0, model.output, cur); ggml_set_name(cur, "result_output"); @@ -1607,6 +1632,10 @@ static bool 
llama_eval_internal( // run the computation ggml_build_forward_expand(&gf, cur); +#if GGML_USE_MPI + ggml_mpi_graph_compute_pre(lctx.ctx_mpi, &gf, n_layer); +#endif + #ifdef GGML_USE_METAL if (lctx.ctx_metal && N == 1) { ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); @@ -1635,6 +1664,15 @@ static bool llama_eval_internal( ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads); #endif +#if GGML_USE_MPI + ggml_mpi_graph_compute_post(lctx.ctx_mpi, &gf, n_layer); +#endif + + // update kv token count + lctx.kv_self.n = n_past + N; + + struct ggml_tensor * res = gf.nodes[gf.n_nodes - 1]; + if (cgraph_fname) { ggml_graph_export(&gf, cgraph_fname); } @@ -1650,23 +1688,17 @@ static bool llama_eval_internal( // ggml_graph_dump_dot(&gf, NULL, "llama.dot"); //} - //embd_w.resize(n_vocab*N); - //memcpy(embd_w.data(), ggml_get_data(cur), sizeof(float)*n_vocab*N); - - // update kv token count - lctx.kv_self.n = n_past + N; - // extract logits { auto & logits_out = lctx.logits; if (lctx.logits_all) { logits_out.resize(n_vocab * N); - memcpy(logits_out.data(), (float *) ggml_get_data(cur), sizeof(float)*n_vocab*N); + memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N); } else { // return result for just the last token logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(cur) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab); } } @@ -2697,6 +2729,18 @@ struct llama_context * llama_new_context_with_model( } #endif +#ifdef GGML_USE_MPI + ctx->ctx_mpi = ggml_mpi_init(); + + if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { + // Enter a blocking eval loop with dummy input, letting rank=0 drive the process + const std::vector<llama_token> tmp(ctx->model.hparams.n_ctx, llama_token_bos()); + while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; + llama_backend_free(); + exit(1); + } +#endif + return ctx; } diff --git a/llama.h b/llama.h index c1e7dab..686463a 100644 --- a/llama.h +++ b/llama.h @@ -158,7 +158,9 @@ extern "C" { // Initialize the llama + ggml backend // If numa is true, use NUMA optimizations // Call once at the start of the program - LLAMA_API void llama_init_backend(bool numa); + LLAMA_API void llama_backend_init(bool numa); + // Call once at the end of the program - currently only used for MPI + LLAMA_API void llama_backend_free(); LLAMA_API int64_t llama_time_us(); diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index 20abe71..87fde16 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -31,6 +31,8 @@ int main(int argc, char **argv) { llama_model * model; llama_context * ctx; + llama_backend_init(false); + // load the vocab { auto lparams = llama_context_default_params(); @@ -97,5 +99,7 @@ int main(int argc, char **argv) { llama_free_model(model); llama_free(ctx); + llama_backend_free(); + return 0; }
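Two short illustrative sketches for the patch above follow; neither is part of the commit. First, the slicing arithmetic in `ggml_mpi_graph_compute_pre` is easier to see with concrete numbers. The standalone C program below (hypothetical inputs: a 32-layer model split across 3 MPI ranks) reproduces only the index computation from that function and prints the layer range each rank would evaluate; rank 0 additionally keeps the tail of the graph (final norm and lm_head), as described in the comment in ggml-mpi.c:

```c
#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int n_layers = 32; // hypothetical model depth
    const int mpi_size = 3;  // hypothetical number of MPI processes

    // same ceiling division as in ggml_mpi_graph_compute_pre
    const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size;

    for (int mpi_rank = 0; mpi_rank < mpi_size; mpi_rank++) {
        // rank 0 takes the last slice; ranks 1..n-1 take slices 0..n-2
        const int mpi_idx = mpi_rank > 0 ? mpi_rank - 1 : mpi_size - 1;

        const int il0 = (mpi_idx + 0) * n_per_node;
        const int il1 = MIN(n_layers, (mpi_idx + 1) * n_per_node);

        printf("rank %d: layers [%2d, %2d)\n", mpi_rank, il0, il1);
    }

    return 0;
}
```

With these numbers it prints [22, 32) for rank 0, [ 0, 11) for rank 1, and [11, 22) for rank 2, which matches the ring order used by ggml_mpi_tensor_send/recv: rank 0 sends the input tokens to rank 1, each rank forwards its activations to the next, and rank 0 receives the final hidden state from rank mpi_size - 1.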
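Second, for programs that link against `llama.h`, the rename from `llama_init_backend` to `llama_backend_init` plus the new `llama_backend_free` imply the following setup/teardown shape (a minimal sketch; model loading, tokenization, and sampling are elided):

```cpp
#include "llama.h"

int main() {
    // previously llama_init_backend(numa); the flag still only toggles NUMA optimizations
    llama_backend_init(false);

    // ... load a model, create a llama_context, run evaluation ...

    // new in this patch: call once at the end of the program - currently only needed for MPI,
    // where it finalizes MPI via ggml_mpi_backend_free()
    llama_backend_free();

    return 0;
}
```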