mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-09 23:29:44 +00:00
ggml : fix quantize_row_q4_1() ARM_NEON (close #876)
This commit is contained in:
parent
180b693a47
commit
684da25926
1 changed files with 4 additions and 6 deletions
10
ggml.c
10
ggml.c
|
@ -599,10 +599,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
||||||
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
|
for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]);
|
||||||
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
|
for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]);
|
||||||
|
|
||||||
// absolute max
|
const float amax = vmaxvq_f32(amaxv[0]);
|
||||||
const float amax = MAX(
|
|
||||||
MAX(vgetq_lane_f32(amaxv[0], 0), vgetq_lane_f32(amaxv[0], 1)),
|
|
||||||
MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
|
|
||||||
|
|
||||||
const float d = amax / ((1 << 3) - 1);
|
const float d = amax / ((1 << 3) - 1);
|
||||||
const float id = d ? 1.0f/d : 0.0f;
|
const float id = d ? 1.0f/d : 0.0f;
|
||||||
|
@ -924,7 +921,7 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
|
||||||
float32x4_t minv[8];
|
float32x4_t minv[8];
|
||||||
float32x4_t maxv[8];
|
float32x4_t maxv[8];
|
||||||
|
|
||||||
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l);
|
for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK + 4*l);
|
||||||
|
|
||||||
for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
|
for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
|
||||||
for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
|
for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
|
||||||
|
@ -947,7 +944,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
|
||||||
|
|
||||||
for (int l = 0; l < 8; l++) {
|
for (int l = 0; l < 8; l++) {
|
||||||
const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
|
const float32x4_t v = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
|
||||||
const int32x4_t vi = vcvtq_s32_f32(v);
|
const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(0.5f)); // needed to round to nearest
|
||||||
|
const int32x4_t vi = vcvtq_s32_f32(vf);
|
||||||
|
|
||||||
y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
|
y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
|
||||||
y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
|
y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
|
||||||
|
|
Loading…
Reference in a new issue