mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-09 23:29:44 +00:00
Add AVX2 implementation of dequantize_row_q4_1 (#505)
This commit is contained in:
parent
a316a425d0
commit
459e93cce0
1 changed file with 33 additions and 1 deletion
34
ggml.c
34
ggml.c
|
@ -783,7 +783,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
|
||||||
|
|
||||||
// Scale and store
|
// Scale and store
|
||||||
for (int j = 0; j < 4; j++) {
|
for (int j = 0; j < 4; j++) {
|
||||||
__m256 result = _mm256_mul_ps(vf[j], d_v);
|
const __m256 result = _mm256_mul_ps(vf[j], d_v);
|
||||||
_mm256_storeu_ps(y + i * QK + l + j*8, result);
|
_mm256_storeu_ps(y + i * QK + l + j*8, result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -879,6 +879,37 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
|
||||||
const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
|
const uint8_t * restrict pm = ((const uint8_t *)x + 0*bs + sizeof(float));
|
||||||
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
|
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + 2*sizeof(float));
|
||||||
|
|
||||||
|
#if defined(__AVX2__)
|
||||||
|
for (int i = 0; i < nb; i++) {
|
||||||
|
const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
|
||||||
|
const __m256 d_m = _mm256_broadcast_ss((const float *) (pm + i*bs));
|
||||||
|
|
||||||
|
const uint8_t * restrict pp = pb + i*bs;
|
||||||
|
|
||||||
|
for (int l = 0; l < QK; l += 32) {
|
||||||
|
// Load 32x4-bit integers into 32x8-bit integers
|
||||||
|
__m256i vx8 = bytesFromNibbles(pp+l/2);
|
||||||
|
|
||||||
|
// Convert to 16-bit int
|
||||||
|
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
|
||||||
|
const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
|
||||||
|
|
||||||
|
// Convert to 32-bit int -> float 32
|
||||||
|
const __m256 vf[4] = {
|
||||||
|
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
|
||||||
|
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
|
||||||
|
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
|
||||||
|
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
|
||||||
|
};
|
||||||
|
|
||||||
|
// Scale, add m and store
|
||||||
|
for (int j = 0; j < 4; j++) {
|
||||||
|
const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
|
||||||
|
_mm256_storeu_ps(y + i * QK + l + j*8, result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
for (int i = 0; i < nb; i++) {
|
for (int i = 0; i < nb; i++) {
|
||||||
const float d = *(const float *) (pd + i*bs);
|
const float d = *(const float *) (pd + i*bs);
|
||||||
const float m = *(const float *) (pm + i*bs);
|
const float m = *(const float *) (pm + i*bs);
|
||||||
|
@ -901,6 +932,7 @@ void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
|
||||||
assert(!isnan(y[i*QK + l + 1]));
|
assert(!isnan(y[i*QK + l + 1]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|
Loading…
Reference in a new issue