mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-09 15:29:43 +00:00
ggml: add names to tensors (#1268)
* ggml: add names to tensors * minor improvements to dot file formatting
This commit is contained in:
parent
f4cef87edf
commit
2d099e5193
3 changed files with 68 additions and 20 deletions
56
ggml.c
56
ggml.c
|
@ -4541,6 +4541,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
|
||||||
/*.perf_cycles =*/ 0,
|
/*.perf_cycles =*/ 0,
|
||||||
/*.perf_time_us =*/ 0,
|
/*.perf_time_us =*/ 0,
|
||||||
/*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
|
/*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
|
||||||
|
/*.name =*/ { 0 },
|
||||||
/*.pad =*/ { 0 },
|
/*.pad =*/ { 0 },
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -4895,6 +4896,15 @@ float * ggml_get_data_f32(const struct ggml_tensor * tensor) {
|
||||||
return (float *)(tensor->data);
|
return (float *)(tensor->data);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const char * ggml_get_name(const struct ggml_tensor * tensor) {
|
||||||
|
return tensor->name;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_set_name(struct ggml_tensor * tensor, const char * name) {
|
||||||
|
strncpy(tensor->name, name, sizeof(tensor->name));
|
||||||
|
tensor->name[sizeof(tensor->name) - 1] = '\0';
|
||||||
|
}
|
||||||
|
|
||||||
struct ggml_tensor * ggml_view_tensor(
|
struct ggml_tensor * ggml_view_tensor(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
const struct ggml_tensor * src) {
|
const struct ggml_tensor * src) {
|
||||||
|
@ -5994,6 +6004,7 @@ struct ggml_tensor * ggml_diag_mask_inf(
|
||||||
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||||
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
||||||
struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
|
struct ggml_tensor * b = ggml_new_i32(ctx, n_past);
|
||||||
|
ggml_set_name(b, "n_past");
|
||||||
|
|
||||||
result->op = GGML_OP_DIAG_MASK_INF;
|
result->op = GGML_OP_DIAG_MASK_INF;
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
|
@ -6051,6 +6062,7 @@ struct ggml_tensor * ggml_rope(
|
||||||
((int32_t *) b->data)[0] = n_past;
|
((int32_t *) b->data)[0] = n_past;
|
||||||
((int32_t *) b->data)[1] = n_dims;
|
((int32_t *) b->data)[1] = n_dims;
|
||||||
((int32_t *) b->data)[2] = mode;
|
((int32_t *) b->data)[2] = mode;
|
||||||
|
ggml_set_name(b, "n_past, n_dims, mode");
|
||||||
|
|
||||||
result->op = GGML_OP_ROPE;
|
result->op = GGML_OP_ROPE;
|
||||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||||
|
@ -12118,10 +12130,16 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
|
||||||
snprintf(color, sizeof(color), "white");
|
snprintf(color, sizeof(color), "white");
|
||||||
}
|
}
|
||||||
|
|
||||||
fprintf(fp, " \"%p\" [ \
|
fprintf(fp, " \"%p\" [ "
|
||||||
style = filled; fillcolor = %s; shape = record; \
|
"style = filled; fillcolor = %s; shape = record; "
|
||||||
label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
|
"label=\"",
|
||||||
(void *) node, color,
|
(void *) node, color);
|
||||||
|
|
||||||
|
if (strlen(node->name) > 0) {
|
||||||
|
fprintf(fp, "%s |", node->name);
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(fp, "%d [%" PRId64 ", %" PRId64 "] | <x>%s",
|
||||||
i, node->ne[0], node->ne[1],
|
i, node->ne[0], node->ne[1],
|
||||||
GGML_OP_SYMBOL[node->op]);
|
GGML_OP_SYMBOL[node->op]);
|
||||||
|
|
||||||
|
@ -12137,18 +12155,26 @@ label=\"%d [%" PRId64 ", %" PRId64 "] | <x>%s",
|
||||||
|
|
||||||
snprintf(color, sizeof(color), "pink");
|
snprintf(color, sizeof(color), "pink");
|
||||||
|
|
||||||
if (ggml_nelements(node) == 1) {
|
fprintf(fp, " \"%p\" [ "
|
||||||
fprintf(fp, " \"%p\" [ \
|
"style = filled; fillcolor = %s; shape = record; "
|
||||||
style = filled; fillcolor = %s; shape = record; \
|
"label=\"<x>",
|
||||||
label=\"<x>%.1e\"; ]\n",
|
(void *) node, color);
|
||||||
(void *) node, color, (double)ggml_get_f32_1d(node, 0));
|
|
||||||
} else {
|
if (strlen(node->name) > 0) {
|
||||||
fprintf(fp, " \"%p\" [ \
|
fprintf(fp, "%s | ", node->name);
|
||||||
style = filled; fillcolor = %s; shape = record; \
|
|
||||||
label=\"<x>CONST %d [%" PRId64 ", %" PRId64 "]\"; ]\n",
|
|
||||||
(void *) node, color,
|
|
||||||
i, node->ne[0], node->ne[1]);
|
|
||||||
}
|
}
|
||||||
|
if (ggml_nelements(node) == 1) {
|
||||||
|
if (node->type == GGML_TYPE_I8 || node->type == GGML_TYPE_I16 || node->type == GGML_TYPE_I32) {
|
||||||
|
fprintf(fp, "%d", ggml_get_i32_1d(node, 0));
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
fprintf(fp, "%.1e", (double)ggml_get_f32_1d(node, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
fprintf(fp, "CONST %d [%" PRId64 ", %" PRId64 "]", i, node->ne[0], node->ne[1]);
|
||||||
|
}
|
||||||
|
fprintf(fp, "\"; ]\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < gb->n_nodes; i++) {
|
for (int i = 0; i < gb->n_nodes; i++) {
|
||||||
|
|
8
ggml.h
8
ggml.h
|
@ -350,7 +350,10 @@ extern "C" {
|
||||||
int64_t perf_time_us;
|
int64_t perf_time_us;
|
||||||
|
|
||||||
void * data;
|
void * data;
|
||||||
char padding[8];
|
|
||||||
|
char name[32];
|
||||||
|
|
||||||
|
char padding[8]; // TODO: remove and add padding to name?
|
||||||
};
|
};
|
||||||
|
|
||||||
// computation graph
|
// computation graph
|
||||||
|
@ -473,6 +476,9 @@ extern "C" {
|
||||||
GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
|
GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
|
||||||
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
|
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
|
||||||
|
|
||||||
|
GGML_API const char * ggml_get_name(const struct ggml_tensor * tensor);
|
||||||
|
GGML_API void ggml_set_name(struct ggml_tensor * tensor, const char * name);
|
||||||
|
|
||||||
//
|
//
|
||||||
// operations on tensors with backpropagation
|
// operations on tensors with backpropagation
|
||||||
//
|
//
|
||||||
|
|
24
llama.cpp
24
llama.cpp
|
@ -659,6 +659,7 @@ struct llama_model_loader {
|
||||||
LLAMA_ASSERT(lt.ne.size() == 1);
|
LLAMA_ASSERT(lt.ne.size() == 1);
|
||||||
tensor = ggml_new_tensor_1d(ggml_ctx, lt.type, lt.ne.at(0));
|
tensor = ggml_new_tensor_1d(ggml_ctx, lt.type, lt.ne.at(0));
|
||||||
}
|
}
|
||||||
|
ggml_set_name(tensor, lt.name.c_str());
|
||||||
LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor
|
LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor
|
||||||
lt.ggml_tensor = tensor;
|
lt.ggml_tensor = tensor;
|
||||||
num_ggml_tensors_created++;
|
num_ggml_tensors_created++;
|
||||||
|
@ -798,6 +799,8 @@ static bool kv_cache_init(
|
||||||
|
|
||||||
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
||||||
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
||||||
|
ggml_set_name(cache.k, "cache_k");
|
||||||
|
ggml_set_name(cache.v, "cache_v");
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -1084,6 +1087,7 @@ static bool llama_eval_internal(
|
||||||
gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
|
gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;
|
||||||
|
|
||||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||||
|
ggml_set_name(embd, "embd");
|
||||||
memcpy(embd->data, tokens, N*ggml_element_size(embd));
|
memcpy(embd->data, tokens, N*ggml_element_size(embd));
|
||||||
|
|
||||||
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
|
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
|
||||||
|
@ -1110,6 +1114,8 @@ static bool llama_eval_internal(
|
||||||
// compute Q and K and RoPE them
|
// compute Q and K and RoPE them
|
||||||
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||||
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
|
||||||
|
ggml_set_name(Qcur, "Qcur");
|
||||||
|
ggml_set_name(Kcur, "Kcur");
|
||||||
|
|
||||||
// store key and value to memory
|
// store key and value to memory
|
||||||
{
|
{
|
||||||
|
@ -1130,6 +1136,7 @@ static bool llama_eval_internal(
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
Qcur,
|
Qcur,
|
||||||
0, 2, 1, 3);
|
0, 2, 1, 3);
|
||||||
|
ggml_set_name(Q, "Q");
|
||||||
|
|
||||||
struct ggml_tensor * K =
|
struct ggml_tensor * K =
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
|
@ -1137,21 +1144,26 @@ static bool llama_eval_internal(
|
||||||
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
|
ggml_view_1d(ctx0, kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.k)*n_embd),
|
||||||
n_embd/n_head, n_head, n_past + N),
|
n_embd/n_head, n_head, n_past + N),
|
||||||
0, 2, 1, 3);
|
0, 2, 1, 3);
|
||||||
|
ggml_set_name(K, "K");
|
||||||
|
|
||||||
// K * Q
|
// K * Q
|
||||||
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
|
||||||
|
ggml_set_name(KQ, "KQ");
|
||||||
|
|
||||||
// KQ_scaled = KQ / sqrt(n_embd/n_head)
|
// KQ_scaled = KQ / sqrt(n_embd/n_head)
|
||||||
struct ggml_tensor * KQ_scaled =
|
struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
|
||||||
ggml_scale(ctx0,
|
ggml_set_name(KQ_scale, "1/sqrt(n_embd/n_head)");
|
||||||
KQ,
|
|
||||||
ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
|
struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale);
|
||||||
|
ggml_set_name(KQ_scaled, "KQ_scaled");
|
||||||
|
|
||||||
// KQ_masked = mask_past(KQ_scaled)
|
// KQ_masked = mask_past(KQ_scaled)
|
||||||
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
|
||||||
|
ggml_set_name(KQ_masked, "KQ_masked");
|
||||||
|
|
||||||
// KQ = soft_max(KQ_masked)
|
// KQ = soft_max(KQ_masked)
|
||||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
|
||||||
|
ggml_set_name(KQ_soft_max, "KQ_soft_max");
|
||||||
|
|
||||||
// split cached V into n_head heads
|
// split cached V into n_head heads
|
||||||
struct ggml_tensor * V =
|
struct ggml_tensor * V =
|
||||||
|
@ -1160,9 +1172,11 @@ static bool llama_eval_internal(
|
||||||
n_ctx*ggml_element_size(kv_self.v),
|
n_ctx*ggml_element_size(kv_self.v),
|
||||||
n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
|
n_ctx*ggml_element_size(kv_self.v)*n_embd/n_head,
|
||||||
il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
|
il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
|
||||||
|
ggml_set_name(V, "V");
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||||
|
ggml_set_name(KQV, "KQV");
|
||||||
#else
|
#else
|
||||||
// make V contiguous in memory to speed up the matmul, however we waste time on the copy
|
// make V contiguous in memory to speed up the matmul, however we waste time on the copy
|
||||||
// on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
|
// on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
|
||||||
|
@ -1173,11 +1187,13 @@ static bool llama_eval_internal(
|
||||||
|
|
||||||
// KQV_merged = KQV.permute(0, 2, 1, 3)
|
// KQV_merged = KQV.permute(0, 2, 1, 3)
|
||||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
|
ggml_set_name(KQV_merged, "KQV_merged");
|
||||||
|
|
||||||
// cur = KQV_merged.contiguous().view(n_embd, N)
|
// cur = KQV_merged.contiguous().view(n_embd, N)
|
||||||
cur = ggml_cpy(ctx0,
|
cur = ggml_cpy(ctx0,
|
||||||
KQV_merged,
|
KQV_merged,
|
||||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
|
||||||
|
ggml_set_name(cur, "KQV_merged_contiguous");
|
||||||
|
|
||||||
// projection (no bias)
|
// projection (no bias)
|
||||||
cur = ggml_mul_mat(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
|
|
Loading…
Reference in a new issue