mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-14 00:59:43 +00:00
A better packNibbles
and mul_sum_i8_pairs_float
implementation using AVX512 (#1119)
This commit is contained in:
parent
0e018fe008
commit
c9e2c26f41
1 changed files with 12 additions and 0 deletions
12
ggml.c
12
ggml.c
|
@ -509,14 +509,25 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
|
||||||
const __m256i ax = _mm256_sign_epi8(x, x);
|
const __m256i ax = _mm256_sign_epi8(x, x);
|
||||||
// Sign the values of the y vectors
|
// Sign the values of the y vectors
|
||||||
const __m256i sy = _mm256_sign_epi8(y, x);
|
const __m256i sy = _mm256_sign_epi8(y, x);
|
||||||
|
#if __AVXVNNI__
|
||||||
|
const __m256i zero = _mm256_setzero_si256();
|
||||||
|
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
|
||||||
|
return _mm256_cvtepi32_ps(summed_pairs);
|
||||||
|
#else
|
||||||
// Perform multiplication and create 16-bit values
|
// Perform multiplication and create 16-bit values
|
||||||
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
||||||
return sum_i16_pairs_float(dot);
|
return sum_i16_pairs_float(dot);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline __m128i packNibbles( __m256i bytes )
|
static inline __m128i packNibbles( __m256i bytes )
|
||||||
{
|
{
|
||||||
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
|
||||||
|
#if __AVX512F__
|
||||||
|
const __m256i bytes_srli_4 = _mm256_srli_epi16(bytes, 4); // 0000_0000_abcd_0000
|
||||||
|
bytes = _mm256_or_si256(bytes, bytes_srli_4); // 0000_abcd_abcd_efgh
|
||||||
|
return _mm256_cvtepi16_epi8(bytes); // abcd_efgh
|
||||||
|
#else
|
||||||
const __m256i lowByte = _mm256_set1_epi16( 0xFF );
|
const __m256i lowByte = _mm256_set1_epi16( 0xFF );
|
||||||
__m256i high = _mm256_andnot_si256( lowByte, bytes );
|
__m256i high = _mm256_andnot_si256( lowByte, bytes );
|
||||||
__m256i low = _mm256_and_si256( lowByte, bytes );
|
__m256i low = _mm256_and_si256( lowByte, bytes );
|
||||||
|
@ -527,6 +538,7 @@ static inline __m128i packNibbles( __m256i bytes )
|
||||||
__m128i r0 = _mm256_castsi256_si128( bytes );
|
__m128i r0 = _mm256_castsi256_si128( bytes );
|
||||||
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
|
__m128i r1 = _mm256_extracti128_si256( bytes, 1 );
|
||||||
return _mm_packus_epi16( r0, r1 );
|
return _mm_packus_epi16( r0, r1 );
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
||||||
|
|
Loading…
Reference in a new issue