mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-09 15:29:43 +00:00
ggml : add AVX support based on AVX2 code (#1430)
This commit is contained in:
parent
601a033475
commit
60f8c361ca
1 changed files with 132 additions and 3 deletions
135
ggml.c
135
ggml.c
|
@ -580,7 +580,63 @@ static inline __m128i packNibbles( __m256i bytes )
|
|||
return _mm_packus_epi16( r0, r1 );
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
#elif defined(__AVX__)
|
||||
// spread 32 bits to 32 bytes { 0x00, 0xFF }
|
||||
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
|
||||
uint32_t x32;
|
||||
memcpy(&x32, x, sizeof(uint32_t));
|
||||
const __m128i shuf_maskl = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000);
|
||||
const __m128i shuf_maskh = _mm_set_epi64x(0x0303030303030303, 0x0202020202020202);
|
||||
__m128i bytesl = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskl);
|
||||
__m128i bytesh = _mm_shuffle_epi8(_mm_set1_epi32(x32), shuf_maskh);
|
||||
const __m128i bit_mask = _mm_set1_epi64x(0x7fbfdfeff7fbfdfe);
|
||||
bytesl = _mm_or_si128(bytesl, bit_mask);
|
||||
bytesh = _mm_or_si128(bytesh, bit_mask);
|
||||
bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
|
||||
bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
|
||||
return _mm256_set_m128i(bytesh, bytesl);
|
||||
}
|
||||
|
||||
// Unpack 32 4-bit fields into 32 bytes
|
||||
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
||||
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
|
||||
{
|
||||
// Load 16 bytes from memory
|
||||
__m128i tmpl = _mm_loadu_si128((const __m128i *)rsi);
|
||||
__m128i tmph = _mm_srli_epi16(tmpl, 4);
|
||||
const __m128i lowMask = _mm_set1_epi8(0xF);
|
||||
tmpl = _mm_and_si128(lowMask, tmpl);
|
||||
tmph = _mm_and_si128(lowMask, tmph);
|
||||
return _mm256_set_m128i(tmph, tmpl);
|
||||
}
|
||||
|
||||
// add int16_t pairwise and return as float vector
|
||||
static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
|
||||
const __m128i ones = _mm_set1_epi16(1);
|
||||
const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
|
||||
const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
|
||||
const __m256i summed_pairs = _mm256_set_m128i(summed_pairsh, summed_pairsl);
|
||||
return _mm256_cvtepi32_ps(summed_pairs);
|
||||
}
|
||||
|
||||
// multiply int8_t, add results pairwise twice and return as float vector
|
||||
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
|
||||
const __m128i xl = _mm256_castsi256_si128(x);
|
||||
const __m128i xh = _mm256_extractf128_si256(x, 1);
|
||||
const __m128i yl = _mm256_castsi256_si128(y);
|
||||
const __m128i yh = _mm256_extractf128_si256(y, 1);
|
||||
// Get absolute values of x vectors
|
||||
const __m128i axl = _mm_sign_epi8(xl, xl);
|
||||
const __m128i axh = _mm_sign_epi8(xh, xh);
|
||||
// Sign the values of the y vectors
|
||||
const __m128i syl = _mm_sign_epi8(yl, xl);
|
||||
const __m128i syh = _mm_sign_epi8(yh, xh);
|
||||
// Perform multiplication and create 16-bit values
|
||||
const __m128i dotl = _mm_maddubs_epi16(axl, syl);
|
||||
const __m128i doth = _mm_maddubs_epi16(axh, syh);
|
||||
return sum_i16_pairs_float(doth, dotl);
|
||||
}
|
||||
|
||||
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
||||
{
|
||||
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
||||
|
@ -2355,7 +2411,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
|
|||
}
|
||||
|
||||
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs;
|
||||
#elif defined(__AVX2__)
|
||||
#elif defined(__AVX2__) || defined(__AVX__)
|
||||
// Initialize accumulator with zeros
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
|
@ -2381,7 +2437,11 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
|
|||
const __m256 xy = mul_sum_i8_pairs_float(bx, by);
|
||||
|
||||
// Accumulate d0*d1*x*y
|
||||
#if defined(__AVX2__)
|
||||
acc = _mm256_fmadd_ps( d0d1, xy, acc );
|
||||
#else
|
||||
acc = _mm256_add_ps( _mm256_mul_ps( d0d1, xy ), acc );
|
||||
#endif
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc) + summs;
|
||||
|
@ -2592,6 +2652,37 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
|
|||
acc = _mm256_fmadd_ps(d, q, acc);
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
#elif defined(__AVX__)
|
||||
// Initialize accumulator with zeros
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
__m128i mask = _mm_set1_epi8((char)0xF0);
|
||||
|
||||
// Main loop
|
||||
for (int i = 0; i < nb; i++) {
|
||||
/* Compute combined scale for the block */
|
||||
const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d));
|
||||
|
||||
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
||||
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
|
||||
__m128i bxhil = _mm256_castsi256_si128(bxhi);
|
||||
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
|
||||
bxhil = _mm_andnot_si128(bxhil, mask);
|
||||
bxhih = _mm_andnot_si128(bxhih, mask);
|
||||
__m128i bxl = _mm256_castsi256_si128(bx);
|
||||
__m128i bxh = _mm256_extractf128_si256(bx, 1);
|
||||
bxl = _mm_or_si128(bxl, bxhil);
|
||||
bxh = _mm_or_si128(bxh, bxhih);
|
||||
bx = _mm256_set_m128i(bxh, bxl);
|
||||
|
||||
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
||||
|
||||
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
||||
|
||||
/* Multiply q with scale and accumulate */
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
#else
|
||||
// scalar
|
||||
|
@ -2820,6 +2911,40 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
|
|||
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc) + summs;
|
||||
#elif defined(__AVX__)
|
||||
// Initialize accumulator with zeros
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
__m128i mask = _mm_set1_epi8(0x10);
|
||||
|
||||
float summs = 0.0f;
|
||||
|
||||
// Main loop
|
||||
for (int i = 0; i < nb; i++) {
|
||||
const __m256 dx = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d));
|
||||
|
||||
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
|
||||
|
||||
__m256i bx = bytes_from_nibbles_32(x[i].qs);
|
||||
const __m256i bxhi = bytes_from_bits_32(x[i].qh);
|
||||
__m128i bxhil = _mm256_castsi256_si128(bxhi);
|
||||
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
|
||||
bxhil = _mm_and_si128(bxhil, mask);
|
||||
bxhih = _mm_and_si128(bxhih, mask);
|
||||
__m128i bxl = _mm256_castsi256_si128(bx);
|
||||
__m128i bxh = _mm256_extractf128_si256(bx, 1);
|
||||
bxl = _mm_or_si128(bxl, bxhil);
|
||||
bxh = _mm_or_si128(bxh, bxhih);
|
||||
bx = _mm256_set_m128i(bxh, bxl);
|
||||
|
||||
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
|
||||
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
|
||||
|
||||
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
||||
|
||||
acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc) + summs;
|
||||
#else
|
||||
// scalar
|
||||
|
@ -2910,7 +3035,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
|
|||
}
|
||||
|
||||
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
||||
#elif defined(__AVX2__)
|
||||
#elif defined(__AVX2__) || defined(__AVX__)
|
||||
// Initialize accumulator with zeros
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
|
@ -2924,7 +3049,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
|
|||
const __m256 q = mul_sum_i8_pairs_float(bx, by);
|
||||
|
||||
// Multiply q with scale and accumulate
|
||||
#if defined(__AVX2__)
|
||||
acc = _mm256_fmadd_ps( d, q, acc );
|
||||
#else
|
||||
acc = _mm256_add_ps( _mm256_mul_ps( d, q ), acc );
|
||||
#endif
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
|
|
Loading…
Reference in a new issue