Speedup the AVX-512 implementation of ggml_vec_dot_q4_0() (#933)

This commit is contained in:
Ivan Komarov 2023-04-17 15:10:57 +02:00 committed by GitHub
parent 47f61aaa5f
commit f266259ad9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 235 additions and 44 deletions

View file

@ -55,6 +55,8 @@ option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer"
option(LLAMA_AVX "llama: enable AVX" ON)
option(LLAMA_AVX2 "llama: enable AVX2" ON)
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
option(LLAMA_FMA "llama: enable FMA" ON)
# in MSVC F16C is implied with AVX2/AVX512
if (NOT MSVC)
@ -220,6 +222,16 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
if (MSVC)
if (LLAMA_AVX512)
add_compile_options(/arch:AVX512)
# MSVC has no compile-time flags enabling specific
# AVX512 extensions, neither it defines the
# macros corresponding to the extensions.
# Do it manually.
if (LLAMA_AVX512_VBMI)
add_compile_definitions(__AVX512VBMI__)
endif()
if (LLAMA_AVX512_VNNI)
add_compile_definitions(__AVX512VNNI__)
endif()
elseif (LLAMA_AVX2)
add_compile_options(/arch:AVX2)
elseif (LLAMA_AVX)
@ -240,9 +252,13 @@ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
endif()
if (LLAMA_AVX512)
add_compile_options(-mavx512f)
# add_compile_options(-mavx512cd)
# add_compile_options(-mavx512dq)
# add_compile_options(-mavx512bw)
add_compile_options(-mavx512bw)
endif()
if (LLAMA_AVX512_VBMI)
add_compile_options(-mavx512vbmi)
endif()
if (LLAMA_AVX512_VNNI)
add_compile_options(-mavx512vnni)
endif()
endif()
else()

229
ggml.c
View file

@ -1977,33 +1977,187 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
}
#if __AVX512F__ && QK4_0 == 32
static inline __m512 dot_q4_0_oneblock_avx512(
static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
// The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// | :. =_ () [] <> () Zz Yy|
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
// |Xx Ww Vv Uu Tt Ss Rr Qq Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa |
// +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
//
// Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
// We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
// Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
// Bytes 40..63 are masked when loading the data, so they are zeroed out.
#ifdef __AVX512VBMI__
const __m512i byte_perm = _mm512_set_epi8(
39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
11, 10, 11, 10, 9, 8, 9, 8, 7, 6, 7, 6, 5, 4, 5, 4
);
const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
// After applying VPERMB, `permuted` looks like this:
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
#else
const __m512i word_perm = _mm512_set_epi16(
19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2
);
const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
// This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
// VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
// Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
#endif
// Shift every odd-numbered 16-bit group to the right by 4 bits.
const __mmask32 shift_mask = 0xaaaaaaaa;
const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
// After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// | : .= :. =_ ( )[ () [] < >( <> () Z zY Zz Yy X xW Xx Ww V vU Vv Uu T tS Tt Ss R rQ Rr Qq
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// | P pO Pp Oo N nM Nn Mm L lK Ll Kk J jI Jj Ii H hG Hh Gg F fE Ff Ee D dC Dd Cc B bA Bb Aa|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// Now we just need to zero out the higher nibble in each byte, and we're done.
const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
return _mm512_and_si512( low_nibble_mask, shifted );
// The final result looks like this:
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// | : = . _ ( [ ) ] < ( > ) Z Y z y X W x w V U v u T S t s R Q r q|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
// | P O p o N M n m L K l k J I j i H G h g F E f e D C d c B A b a|
// +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
}
static inline __m512 dot_q4_0_twoblocks_avx512(
__m512 acc,
const block_q4_0 * restrict x,
const block_q4_0 * restrict y,
int i
) {
// Compute combined scale for the block
__m512 d = _mm512_set1_ps( x[i].d * y[i].d );
// A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
// can potentially be unaddressable, so we make sure to mask them out before the load, even though
// we don't use them at all. This might hurt the performance slightly, since the compiler is forced
// to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
const __mmask8 load_mask = 0x1f;
const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
__m256i bx = bytesFromNibbles( x[i].qs );
__m256i by = bytesFromNibbles( y[i].qs );
// We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// blocks_0_float
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// | | | | | | | xx | xx | xx | xx | B | xx | xx | xx | xx | A |
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// blocks_1_float
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// | | | | | | | xx | xx | xx | xx | D | xx | xx | xx | xx | C |
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
// We absolutely shouldn't touch the floats marked with `xx`: they contain some
// random data, which might very well underflow. At least on Intel, this leads
// to a huge penalty that can't be ignored (easily 100x or more) unless you
// compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
// (and ggml can't assume that you do)...
const __mmask16 scale_mul_mask = 0x21;
#ifdef __clang__
// ...however, clang decides to optimize the multiplication mask away:
// https://godbolt.org/z/P8PqdsfvW
// gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
__m512i scales;
__asm__(
"vmulps %1, %2, %0%{%3%}"
: "=v" ( scales )
: "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
);
#else
const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
#endif
const __m512i scale_perm = _mm512_set_epi32(
5, 5, 5, 5, 5, 5, 5, 5,
0, 0, 0, 0, 0, 0, 0, 0
);
const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
// After VMULPS and VPERMPS, `permuted_scales` looks like this:
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
// +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m256i off = _mm256_set1_epi8( 8 );
bx = _mm256_sub_epi8( bx, off );
by = _mm256_sub_epi8( by, off );
const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
// Sign-extend 16 signed bytes into int16_t
__m512i x32 = _mm512_cvtepi8_epi16( bx );
__m512i y32 = _mm512_cvtepi8_epi16( by );
// Compute products of int16_t integers, add pairwise
__m512i i64 = _mm512_madd_epi16( x32, y32 );
// Now we want to compute dot products of 4-element byte vectors and store them in
// 32-bit integers. That is (only one 4-element vector is shown for clarity):
// +----+----+----+----+
// ... | 03 | 02 | 01 | 00 |
// +----+----+----+----+
// bytes_0
// +----+----+----+----+
// ... | D | C | B | A |
// +----+----+----+----+
// bytes_1
// +----+----+----+----+
// ... | H | G | F | E |
// +----+----+----+----+
// final_res_int
// +----+----+----+----+
// ... | A*E+B*F+C*G+D*H |
// +----+----+----+----+
const __m512i plus_8 = _mm512_set1_epi8( 8 );
const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
// Convert int32_t to float
__m512 p = _mm512_cvtepi32_ps( i64 );
// Apply the scale, and accumulate
return _mm512_fmadd_ps( d, p, acc );
#ifdef __AVX512VNNI__
// We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
// the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
// from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
// we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
// which means we only need 2 instructions.
const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
const __m512i minus_8 = _mm512_set1_epi8( -8 );
const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
#else
// As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
// It has the same catch as VPDPBUSDS: the left operand should be unsigned.
// This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
// ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
const __m512i one = _mm512_set1_epi16( 1 );
const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
const __m512i final_res_int = _mm512_madd_epi16( diff, one );
#endif
// Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
}
#endif
@ -2135,25 +2289,26 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
__m512 acc0 = _mm512_setzero_ps();
__m512 acc1 = _mm512_setzero_ps();
const int superblock_size = 8;
const int superblock_size = 16;
const int superblock_count = nb / superblock_size;
for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
int i = superblock_ix * superblock_size;
acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+0 );
acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+1 );
acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+2 );
acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+3 );
acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+4 );
acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+5 );
acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i+6 );
acc1 = dot_q4_0_oneblock_avx512( acc1, x, y, i+7 );
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
}
// Remainders
for (int i = superblock_count * superblock_size; i < nb; ++i) {
acc0 = dot_q4_0_oneblock_avx512( acc0, x, y, i );
for (int i = superblock_count * superblock_size; i < nb; i += 2) {
acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
}
// Horizontal sum of all lanes of the accumulator
@ -11303,6 +11458,22 @@ int ggml_cpu_has_avx512(void) {
#endif
}
int ggml_cpu_has_avx512_vbmi(void) {
#if defined(__AVX512VBMI__)
return 1;
#else
return 0;
#endif
}
int ggml_cpu_has_avx512_vnni(void) {
#if defined(__AVX512VNNI__)
return 1;
#else
return 0;
#endif
}
int ggml_cpu_has_fma(void) {
#if defined(__FMA__)
return 1;

2
ggml.h
View file

@ -808,6 +808,8 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
int ggml_cpu_has_avx(void);
int ggml_cpu_has_avx2(void);
int ggml_cpu_has_avx512(void);
int ggml_cpu_has_avx512_vbmi(void);
int ggml_cpu_has_avx512_vnni(void);
int ggml_cpu_has_fma(void);
int ggml_cpu_has_neon(void);
int ggml_cpu_has_arm_fma(void);

View file

@ -1915,18 +1915,20 @@ const char * llama_print_system_info(void) {
static std::string s;
s = "";
s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | ";
s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | ";
s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | ";
s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | ";
s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | ";
s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | ";
s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | ";
s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | ";
s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | ";
s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | ";
s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | ";
s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | ";
s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | ";
s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | ";
return s.c_str();
}