llama : only copy used KV cache in get / set state (#1272)
* llama : only copy used KV cache in get / set state
* switch to ggml for copying k, v
* avoid designated initializers
This commit is contained in: parent 2485d7a4d3, commit e216aa0463
2 changed files with 80 additions and 23 deletions
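For callers, the practical consequence is that llama_get_state_size() now reports an upper bound rather than the exact serialized size, and the number of bytes actually written is the return value of llama_copy_state_data(). A minimal caller-side sketch of the new contract (not part of the patch; it assumes an already-initialized llama_context * ctx and uses only the APIs touched in the diff below):

    // Sketch: save and restore context state under the new "maximum size" contract.
    // Assumes `ctx` is an already-initialized llama_context (creation not shown).
    #include "llama.h"
    #include <cstdint>
    #include <vector>

    static std::vector<uint8_t> save_state(struct llama_context * ctx) {
        // upper bound on the serialized size (rng, logits, embedding, kv cache)
        const size_t max_size = llama_get_state_size(ctx);

        std::vector<uint8_t> buf(max_size);

        // bytes actually written; can be smaller than max_size because only
        // the used part of the KV cache is serialized
        const size_t written = llama_copy_state_data(ctx, buf.data());
        buf.resize(written);
        return buf;
    }

    static void load_state(struct llama_context * ctx, const std::vector<uint8_t> & state) {
        // consumes exactly the bytes produced by llama_copy_state_data()
        llama_set_state_data(ctx, state.data());
    }

llama_save_session_file follows the same pattern in the diff below: it allocates the maximum size, then writes only the n_state_size_cur bytes reported by llama_copy_state_data.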
llama.cpp (98 changes)
@@ -1285,6 +1285,9 @@ static bool llama_eval_internal(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

+    // update kv token count
+    lctx.model.kv_self.n = n_past + N;
+
     // extract logits
     {
         auto & logits_out = lctx.logits;
@@ -2401,7 +2404,7 @@ void llama_set_rng_seed(struct llama_context * ctx, int seed) {
     ctx->rng.seed(seed);
 }

-// Returns the size of the state
+// Returns the *maximum* size of the state
 size_t llama_get_state_size(const struct llama_context * ctx) {
     // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
     // for reference, std::mt19937(1337) serializes to 6701 bytes.
@@ -2480,21 +2483,51 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {

     // copy kv cache
     {
-        const size_t kv_size = ctx->model.kv_self.buf.size;
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int n_layer = hparams.n_layer;
+        const int n_embd = hparams.n_embd;
+        const int n_ctx = hparams.n_ctx;
+
+        const size_t kv_size = kv_self.buf.size;
         const int kv_ntok = llama_get_kv_cache_token_count(ctx);

         memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
         memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);

         if (kv_size) {
-            memcpy(out, ctx->model.kv_self.buf.addr, kv_size); out += kv_size;
+            const size_t elt_size = ggml_element_size(kv_self.k);
+            char buffer[4096];
+            ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
+            ggml_cgraph gf{};
+            gf.n_threads = 1;
+
+            ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            kout3d->data = out;
+            out += ggml_nbytes(kout3d);
+
+            ggml_tensor * vout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            vout3d->data = out;
+            out += ggml_nbytes(vout3d);
+
+            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+                n_embd, kv_ntok, n_layer,
+                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
+
+            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+                kv_ntok, n_embd, n_layer,
+                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
+
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
+            ggml_graph_compute(cpy_ctx, &gf);
         }
     }

     const size_t written = out - dest;
-    const size_t expected = llama_get_state_size(ctx);
+    const size_t max_size = llama_get_state_size(ctx);

-    LLAMA_ASSERT(written == expected);
+    LLAMA_ASSERT(written <= max_size);

     return written;
 }
@@ -2552,6 +2585,12 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {

     // set kv cache
     {
+        const auto & kv_self = ctx->model.kv_self;
+        const auto & hparams = ctx->model.hparams;
+        const int n_layer = hparams.n_layer;
+        const int n_embd = hparams.n_embd;
+        const int n_ctx = hparams.n_ctx;
+
         size_t kv_size;
         int kv_ntok;

@@ -2559,25 +2598,42 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);

         if (kv_size) {
-            LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size);
+            LLAMA_ASSERT(kv_self.buf.size == kv_size);

-            void * k_data = ctx->model.kv_self.k->data; // remember data pointers
-            void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
+            const size_t elt_size = ggml_element_size(kv_self.k);
+            char buffer[4096];
+            ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
+            ggml_cgraph gf{};
+            gf.n_threads = 1;

-            memcpy(ctx->model.kv_self.buf.addr, in, kv_size); in += kv_size;
+            ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
+            kin3d->data = (void *) in;
+            in += ggml_nbytes(kin3d);

-            ctx->model.kv_self.k->data = k_data; // restore correct data pointers
-            ctx->model.kv_self.v->data = v_data;
+            ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
+            vin3d->data = (void *) in;
+            in += ggml_nbytes(vin3d);

+            ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
+                n_embd, kv_ntok, n_layer,
+                elt_size*n_embd, elt_size*n_embd*n_ctx, 0);
+
+            ggml_tensor * v3d = ggml_view_3d(cpy_ctx, kv_self.v,
+                kv_ntok, n_embd, n_layer,
+                elt_size*n_ctx, elt_size*n_ctx*n_embd, 0);
+
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
+            ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
+            ggml_graph_compute(cpy_ctx, &gf);
         }

         ctx->model.kv_self.n = kv_ntok;
     }

     const size_t nread = in - src;
-    const size_t expected = llama_get_state_size(ctx);
+    const size_t max_size = llama_get_state_size(ctx);

-    LLAMA_ASSERT(nread == expected);
+    LLAMA_ASSERT(nread <= max_size);

     return nread;
 }
@@ -2620,14 +2676,14 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
     // restore the context state
     {
         const size_t n_state_size_cur = file.size - file.tell();
-        const size_t n_state_size_exp = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_get_state_size(ctx);

-        if (n_state_size_cur != n_state_size_exp) {
-            fprintf(stderr, "%s : the state size in session file didn't match! expected %zu, got %zu\n", __func__, n_state_size_exp, n_state_size_cur);
+        if (n_state_size_cur > n_state_size_max) {
+            fprintf(stderr, "%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
             return false;
         }

-        std::vector<uint8_t> state_data(n_state_size_cur);
+        std::vector<uint8_t> state_data(n_state_size_max);
         file.read_raw(state_data.data(), n_state_size_cur);

         llama_set_state_data(ctx, state_data.data());
@@ -2650,12 +2706,12 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi

     // save the context state
     {
-        const size_t n_state_size = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_get_state_size(ctx);

-        std::vector<uint8_t> state_data(n_state_size);
-        llama_copy_state_data(ctx, state_data.data());
+        std::vector<uint8_t> state_data(n_state_size_max);
+        const size_t n_state_size_cur = llama_copy_state_data(ctx, state_data.data());

-        file.write_raw(state_data.data(), n_state_size);
+        file.write_raw(state_data.data(), n_state_size_cur);
     }

     return true;
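The core of the change is visible in both the copy and set paths above: a tiny no_alloc ggml context is created in a 4 KB stack buffer, the used part of the cache is described as a strided 3d view (n_embd x kv_ntok x n_layer for k, kv_ntok x n_embd x n_layer for v), and a one-shot graph with ggml_cpy performs the copy, which also handles any element-type conversion. The following standalone sketch illustrates that pattern for the k tensor only; it is not code from the patch, and the helper name copy_used_k and its parameters are hypothetical:

    // Sketch (not from the patch): copy only the first n_tok tokens of a
    // K-cache-shaped tensor into a caller-provided buffer via a throwaway
    // ggml graph. The helper name and parameters are hypothetical.
    #include "ggml.h"

    static void copy_used_k(ggml_tensor * k_src,  // K cache, laid out as [n_embd, n_ctx, n_layer]
                            void * dst,           // destination buffer, at least n_embd*n_tok*n_layer elements
                            int n_embd, int n_ctx, int n_layer, int n_tok) {
        char buffer[4096]; // holds only tensor/graph metadata; no_alloc means no tensor data lives here
        ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });

        ggml_cgraph gf{};  // value-initialized instead of using designated initializers
        gf.n_threads = 1;

        const size_t elt_size = ggml_element_size(k_src);

        // destination: contiguous [n_embd, n_tok, n_layer] tensor backed by dst
        ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, k_src->type, n_embd, n_tok, n_layer);
        kout3d->data = dst;

        // source: strided view selecting only the first n_tok tokens of each layer
        ggml_tensor * k3d = ggml_view_3d(cpy_ctx, k_src,
            n_embd, n_tok, n_layer,
            elt_size*n_embd, elt_size*n_embd*n_ctx, 0);

        ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
        ggml_graph_compute(cpy_ctx, &gf);

        ggml_free(cpy_ctx);
    }

Expressing the copy as a ggml graph over views, rather than a raw memcpy of kv_self.buf as before, is what lets the serialized state shrink to the kv_ntok tokens actually in use.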
llama.h (5 changes)
@@ -23,7 +23,7 @@
 #define LLAMA_FILE_MAGIC 'ggjt'
 #define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml'
 #define LLAMA_SESSION_MAGIC 'ggsn'
-#define LLAMA_SESSION_VERSION 0
+#define LLAMA_SESSION_VERSION 1

 #ifdef __cplusplus
 extern "C" {
@@ -127,7 +127,8 @@ extern "C" {
     // Sets the current rng seed.
     LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, int seed);

-    // Returns the size in bytes of the state (rng, logits, embedding and kv_cache)
+    // Returns the maximum size in bytes of the state (rng, logits, embedding
+    // and kv_cache) - will often be smaller after compacting tokens
     LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);

     // Copies the state to the specified destination address.