mirror of https://git.adityakumar.xyz/llama.cpp.git
ggml : fix rope args order + assert (#2054)
This commit is contained in:
parent 3973b25a64
commit 513f861953
4 changed files with 23 additions and 18 deletions
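In short, this commit moves n_ctx so that it sits directly after mode in every rope argument list, and extends ggml_rope_back (and its recorded parameters) to carry n_ctx as well. A minimal orientation sketch of the new calling convention, assuming the ggml.h declarations shown further down; the helper name and placeholder arguments are illustrative only:

    #include "ggml.h"

    // hypothetical helper; illustrates the new argument order only
    static struct ggml_tensor * apply_rope(struct ggml_context * ctx, struct ggml_tensor * cur,
                                           int n_past, int n_rot, int n_ctx,
                                           float freq_base, float freq_scale) {
        // before: ggml_rope_custom_inplace(ctx, cur, n_past, n_rot, /*mode*/ 0, freq_base, freq_scale, n_ctx);
        // after : n_ctx now comes directly after mode
        return ggml_rope_custom_inplace(ctx, cur, n_past, n_rot, /*mode*/ 0, n_ctx, freq_base, freq_scale);
    }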
@@ -1434,7 +1434,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     gf->perf_time_us = 0;
 
     const auto & hparams = model->hparams;
-    //const int n_ctx = hparams.n_ctx;
+    const int n_ctx = hparams.n_ctx;
     const int n_vocab = hparams.n_vocab;
     const int n_embd = hparams.n_embd;
     const int n_layer = hparams.n_layer;
@@ -1863,10 +1863,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     t12->grad = expand(gb, ggml_permute(ctx0, t15->grad, 0, 2, 3, 1)); assert_shape_4d(t12->grad, N, n_batch, n_embd/n_head, n_head);
     t11->grad = expand(gb, ggml_reshape_2d(ctx0, ggml_cont(ctx0, t12->grad), N*n_batch, n_embd)); assert_shape_2d(t11->grad, N*n_batch, n_embd);
     t10->grad = expand(gb, ggml_permute(ctx0, t14->grad, 0, 2, 1, 3)); assert_shape_4d(t10->grad, n_embd/n_head, n_head, N, n_batch);
-    t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
+    t09->grad = expand(gb, ggml_rope_back(ctx0, t10->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t09->grad, n_embd/n_head, n_head, N, n_batch);
     t08->grad = expand(gb, ggml_reshape_2d(ctx0, t09->grad, n_embd, N*n_batch)); assert_shape_2d(t08->grad, n_embd, N*n_batch);
     t07->grad = expand(gb, ggml_permute(ctx0, t13->grad, 0, 2, 1, 3)); assert_shape_4d(t07->grad, n_embd/n_head, n_head, N, n_batch);
-    t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
+    t06->grad = expand(gb, ggml_rope_back(ctx0, t07->grad, n_past, n_rot, rope_mode, n_ctx)); assert_shape_4d(t06->grad, n_embd/n_head, n_head, N, n_batch);
     t05->grad = expand(gb, ggml_reshape_2d(ctx0, t06->grad, n_embd, N*n_batch)); assert_shape_2d(t05->grad, n_embd, N*n_batch);
     t04->grad = expand(gb, ggml_add_inplace(ctx0,
                 ggml_add_inplace(ctx0,
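The gradient graph built by forward_batch_wo_cache_flash_attn_train now threads n_ctx into ggml_rope_back. A hedged sketch of the updated call shape, with dy standing in for the gradient tensors (t10->grad, t07->grad) above:

    #include "ggml.h"

    // sketch only: mirrors the updated t09->grad / t06->grad lines above
    static struct ggml_tensor * rope_grad(struct ggml_context * ctx0, struct ggml_tensor * dy,
                                          int n_past, int n_rot, int rope_mode, int n_ctx) {
        // ggml_rope_back now takes n_ctx as its final parameter
        return ggml_rope_back(ctx0, dy, n_past, n_rot, rope_mode, n_ctx);
    }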
ggml.c: 24 changes
@@ -6956,9 +6956,9 @@ struct ggml_tensor * ggml_rope_impl(
         int n_past,
         int n_dims,
         int mode,
+        int n_ctx,
         float freq_base,
         float freq_scale,
-        int n_ctx,
         bool inplace) {
     GGML_ASSERT(n_past >= 0);
     bool is_node = false;
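Grouping the integer n_ctx with the other int parameters, instead of leaving it after the two floats, also removes a classic C footgun: a call written against the wrong order still compiles because int and float convert implicitly. A standalone illustration of that hazard (not ggml code; names and values are made up):

    #include <stdio.h>

    // stand-in with the OLD parameter order: floats before the trailing int
    static void rope_old_order(int mode, float freq_base, float freq_scale, int n_ctx) {
        printf("mode=%d freq_base=%.1f freq_scale=%.1f n_ctx=%d\n", mode, freq_base, freq_scale, n_ctx);
    }

    int main(void) {
        // caller written against the NEW order (mode, n_ctx, freq_base, freq_scale):
        // this still compiles, because int and float convert implicitly,
        // but n_ctx lands in freq_base and freq_base lands in freq_scale.
        rope_old_order(0, 2048, 10000.0f, 1.0f);
        // prints: mode=0 freq_base=2048.0 freq_scale=10000.0 n_ctx=1
        return 0;
    }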
@@ -6997,7 +6997,7 @@ struct ggml_tensor * ggml_rope(
         int n_dims,
         int mode,
         int n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, false);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, false);
 }
 
 struct ggml_tensor * ggml_rope_inplace(
@@ -7007,7 +7007,7 @@ struct ggml_tensor * ggml_rope_inplace(
         int n_dims,
         int mode,
         int n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, 10000.0f, 1.0f, n_ctx, true);
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f, true);
 }
 
 struct ggml_tensor * ggml_rope_custom_inplace(
@@ -7016,10 +7016,10 @@ struct ggml_tensor * ggml_rope_custom_inplace(
         int n_past,
         int n_dims,
         int mode,
+        int n_ctx,
         float freq_base,
-        float freq_scale,
-        int n_ctx) {
-    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, freq_base, freq_scale, n_ctx, true);
+        float freq_scale) {
+    return ggml_rope_impl(ctx, a, n_past, n_dims, mode, n_ctx, freq_base, freq_scale, true);
 }
 
 // ggml_rope_back
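All three public wrappers now forward to ggml_rope_impl with n_ctx immediately after mode; ggml_rope and ggml_rope_inplace simply fix freq_base = 10000.0f and freq_scale = 1.0f. A hedged equivalence sketch, with cur and the scalar parameters as placeholders:

    #include "ggml.h"

    static struct ggml_tensor * rope_default_freqs(struct ggml_context * ctx, struct ggml_tensor * cur,
                                                   int n_past, int n_dims, int mode, int n_ctx) {
        // per the wrappers above, ggml_rope_inplace(ctx, cur, n_past, n_dims, mode, n_ctx)
        // forwards to ggml_rope_impl with freq_base = 10000.0f and freq_scale = 1.0f,
        // so it describes the same operation as this explicit call:
        return ggml_rope_custom_inplace(ctx, cur, n_past, n_dims, mode, n_ctx, 10000.0f, 1.0f);
    }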
@@ -7029,7 +7029,8 @@ struct ggml_tensor * ggml_rope_back(
         struct ggml_tensor * a,
         int n_past,
         int n_dims,
-        int mode) {
+        int mode,
+        int n_ctx) {
     GGML_ASSERT(n_past >= 0);
     GGML_ASSERT((mode & 4) == 0 && "ggml_rope_back() for ChatGLM not implemented yet");
 
@@ -7043,12 +7044,13 @@ struct ggml_tensor * ggml_rope_back(
 
     ggml_scratch_save(ctx);
 
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
+    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
     ggml_set_name(b, "n_past, n_dims, mode");
 
     ((int32_t *) b->data)[0] = n_past;
     ((int32_t *) b->data)[1] = n_dims;
     ((int32_t *) b->data)[2] = mode;
+    ((int32_t *) b->data)[3] = n_ctx;
 
     ggml_scratch_load(ctx);
 
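ggml_rope_back now records four int32 parameters in its auxiliary tensor, in the fixed slot order n_past, n_dims, mode, n_ctx; the backward-pass hunk below reads the same slots back in the same order. A standalone sketch of that layout (the values are illustrative, not from the source):

    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    // standalone sketch of the 4-slot parameter layout written by ggml_rope_back
    // slot order taken from the diff: n_past, n_dims, mode, n_ctx
    int main(void) {
        int32_t params[4];
        params[0] = 0;      // n_past  (example value)
        params[1] = 64;     // n_dims  (example value)
        params[2] = 0;      // mode    (example value)
        params[3] = 2048;   // n_ctx   (example value)

        // a consumer must read the same slots back in the same order
        assert(sizeof(params) / sizeof(params[0]) == 4);
        printf("n_past=%d n_dims=%d mode=%d n_ctx=%d\n",
               params[0], params[1], params[2], params[3]);
        return 0;
    }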
@@ -15740,13 +15742,15 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode = ((int32_t *) src1->data)[2];
+                    const int n_ctx = ((int32_t *) src1->data)[3];
                     src0->grad = ggml_add_impl(ctx,
                             src0->grad,
                             ggml_rope_back(ctx,
                                 tensor->grad,
                                 n_past,
                                 n_dims,
-                                mode),
+                                mode,
+                                n_ctx),
                             inplace);
                 }
                 if (src1->grad) {
@@ -15757,7 +15761,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 if (src0->grad) {
                     assert(src1->type == GGML_TYPE_I32);
-                    assert(ggml_nelements(src1) == 3);
+                    assert(ggml_nelements(src1) == 4);
                     const int n_past = ((int32_t *) src1->data)[0];
                     const int n_dims = ((int32_t *) src1->data)[1];
                     const int mode = ((int32_t *) src1->data)[2];
ggml.h: 7 changes
@@ -1128,9 +1128,9 @@ extern "C" {
             int n_past,
             int n_dims,
             int mode,
+            int n_ctx,
             float freq_base,
-            float freq_scale,
-            int n_ctx);
+            float freq_scale);
 
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
@@ -1139,7 +1139,8 @@ extern "C" {
             struct ggml_tensor * a,
             int n_past,
             int n_dims,
-            int mode);
+            int mode,
+            int n_ctx);
 
     // alibi position embedding
     // in-place, returns view(a)
@@ -1452,11 +1452,11 @@ static bool llama_eval_internal(
                 offload_func_kq(tmpq);
                 ggml_set_name(tmpq, "tmpq");
 
-                struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0);
+                struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale);
                 offload_func_kq(Kcur);
                 ggml_set_name(Kcur, "Kcur");
 
-                struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, freq_base, freq_scale, 0);
+                struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd/n_head, n_head, N), n_past, n_rot, 0, 0, freq_base, freq_scale);
                 offload_func_kq(Qcur);
                 ggml_set_name(Qcur, "Qcur");
 
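Both call sites in llama_eval_internal keep passing 0 for n_ctx, but the literal now sits in the slot right after mode. An annotated restatement of the Kcur line, with the parameter names taken from the ggml.h declaration above (the comments are added here for readability; they are not in the source):

    struct ggml_tensor * Kcur = ggml_rope_custom_inplace(
        ctx0,
        ggml_reshape_3d(ctx0, tmpk, n_embd/n_head, n_head, N),  // a
        n_past,       // n_past
        n_rot,        // n_dims
        0,            // mode
        0,            // n_ctx (passed as 0 at this call site, as in the diff)
        freq_base,    // freq_base
        freq_scale);  // freq_scale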