From 44f906e8537fcec965e312d621c80556d6aa9bec Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 6 Jun 2023 20:16:57 +0300
Subject: [PATCH] metal : add f16 support

---
 ggml-metal.m     | 23 +++++++++++++----------
 ggml-metal.metal | 16 ++++++++++++++++
 llama.cpp        |  3 ++-
 3 files changed, 31 insertions(+), 11 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index d721ac6..0953af6 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -47,10 +47,11 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(soft_max);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
+    GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(rms_norm);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(rope);
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -130,10 +131,11 @@ struct ggml_metal_context * ggml_metal_init(void) {
         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(soft_max);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
+        GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(rms_norm);
-        GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(rope);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -498,6 +500,14 @@ void ggml_metal_graph_compute(
 
                             // use custom matrix x vector kernel
                             switch (src0t) {
+                                case GGML_TYPE_F16:
+                                    {
+                                        GGML_ASSERT(ne02 == ne12);
+
+                                        nth0 = 64;
+                                        nth1 = 1;
+                                        [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                    } break;
                                 case GGML_TYPE_Q4_0:
                                     {
                                         GGML_ASSERT(ne02 == 1);
@@ -507,14 +517,6 @@ void ggml_metal_graph_compute(
                                         nth1 = 4;
                                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
                                     } break;
-                                case GGML_TYPE_F16:
-                                    {
-                                        GGML_ASSERT(ne02 == ne12);
-
-                                        nth0 = 32;
-                                        nth1 = 1;
-                                        [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
-                                    } break;
                                 default: GGML_ASSERT(false && "not implemented");
                             };
 
@@ -551,6 +553,7 @@ void ggml_metal_graph_compute(
                             }
 
                             switch (src0->type) {
+                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                                 default: GGML_ASSERT(false && "not implemented");
                             }
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 4bedc8e..a359beb 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -169,6 +169,22 @@ kernel void kernel_diag_mask_inf(
     }
 }
 
+kernel void kernel_get_rows_f16(
+        device const void * src0,
+        device const int * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant uint64_t & nb01,
+        constant uint64_t & nb1,
+        uint tpig[[thread_position_in_grid]]) {
+    const int i = tpig;
+    const int r = ((device int32_t *) src1)[i];
+
+    for (int j = 0; j < ne00; j++) {
+        dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
+    }
+}
+
 kernel void kernel_get_rows_q4_0(
         device const void * src0,
         device const int * src1,
diff --git a/llama.cpp b/llama.cpp
index 70341d0..73f6860 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -961,7 +961,6 @@ static void llama_model_load_internal(
     model.hparams = ml->file_loaders.at(0)->hparams;
     llama_file_version file_version = ml->file_loaders.at(0)->file_version;
     auto & hparams = model.hparams;
-    uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
 
     {
         switch (hparams.n_layer) {
@@ -975,6 +974,8 @@ static void llama_model_load_internal(
         hparams.n_ctx = n_ctx;
     }
 
+    const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
+
     {
         fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version));
         fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
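
Two reviewer notes, not part of the patch itself:

On the llama.cpp hunks: the n_ff expression is only moved, not changed. It rounds (8/3)*n_embd up to a multiple of n_mult; for example, with n_embd = 4096 and n_mult = 256 (LLaMA-7B), n_ff = ((2*(4*4096)/3 + 255)/256)*256 = ((10922 + 255)/256)*256 = 43*256 = 11008.

On the new Metal kernel: the plain-C sketch below is a reference for the gather that kernel_get_rows_f16 performs (one Metal thread per output row), handy for sanity-checking the GPU output against the CPU path. It assumes a contiguous f32 destination with row stride ne00 and reuses ggml_fp16_t and ggml_fp16_to_fp32 from ggml.h; the helper name cpu_get_rows_f16 is illustrative only, not part of the patch.

#include <stdint.h>
#include "ggml.h" // ggml_fp16_t, ggml_fp16_to_fp32

// Reference: dst row i is row src1[i] of the f16 matrix src0, widened to f32.
// nb01 is the src0 row stride in bytes, as elsewhere in ggml.
static void cpu_get_rows_f16(
        const void * src0,  // f16 source matrix
        const int  * src1,  // row indices to gather
        float      * dst,   // f32 output, nrows x ne00 (assumed contiguous)
        int64_t      ne00,  // elements per row
        uint64_t     nb01,  // src0 row stride in bytes
        int64_t      nrows) {
    for (int64_t i = 0; i < nrows; ++i) { // i plays the role of tpig
        const int r = src1[i];
        const ggml_fp16_t * row = (const ggml_fp16_t *) ((const char *) src0 + (size_t) r*nb01);
        for (int64_t j = 0; j < ne00; ++j) {
            dst[i*ne00 + j] = ggml_fp16_to_fp32(row[j]);
        }
    }
}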