mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-09 15:29:43 +00:00
metal : add f16 support
This commit is contained in:
parent
d5b111f53d
commit
44f906e853
3 changed files with 31 additions and 11 deletions
23
ggml-metal.m
23
ggml-metal.m
|
@ -47,10 +47,11 @@ struct ggml_metal_context {
|
||||||
GGML_METAL_DECL_KERNEL(relu);
|
GGML_METAL_DECL_KERNEL(relu);
|
||||||
GGML_METAL_DECL_KERNEL(soft_max);
|
GGML_METAL_DECL_KERNEL(soft_max);
|
||||||
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
GGML_METAL_DECL_KERNEL(diag_mask_inf);
|
||||||
|
GGML_METAL_DECL_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_DECL_KERNEL(rms_norm);
|
GGML_METAL_DECL_KERNEL(rms_norm);
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
|
||||||
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
|
||||||
|
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_DECL_KERNEL(rope);
|
GGML_METAL_DECL_KERNEL(rope);
|
||||||
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
|
||||||
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
|
||||||
|
@ -130,10 +131,11 @@ struct ggml_metal_context * ggml_metal_init(void) {
|
||||||
GGML_METAL_ADD_KERNEL(relu);
|
GGML_METAL_ADD_KERNEL(relu);
|
||||||
GGML_METAL_ADD_KERNEL(soft_max);
|
GGML_METAL_ADD_KERNEL(soft_max);
|
||||||
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
GGML_METAL_ADD_KERNEL(diag_mask_inf);
|
||||||
|
GGML_METAL_ADD_KERNEL(get_rows_f16);
|
||||||
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
|
||||||
GGML_METAL_ADD_KERNEL(rms_norm);
|
GGML_METAL_ADD_KERNEL(rms_norm);
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
|
||||||
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
|
||||||
|
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
|
||||||
GGML_METAL_ADD_KERNEL(rope);
|
GGML_METAL_ADD_KERNEL(rope);
|
||||||
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
|
||||||
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
|
||||||
|
@ -498,6 +500,14 @@ void ggml_metal_graph_compute(
|
||||||
|
|
||||||
// use custom matrix x vector kernel
|
// use custom matrix x vector kernel
|
||||||
switch (src0t) {
|
switch (src0t) {
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(ne02 == ne12);
|
||||||
|
|
||||||
|
nth0 = 64;
|
||||||
|
nth1 = 1;
|
||||||
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
||||||
|
} break;
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ne02 == 1);
|
GGML_ASSERT(ne02 == 1);
|
||||||
|
@ -507,14 +517,6 @@ void ggml_metal_graph_compute(
|
||||||
nth1 = 4;
|
nth1 = 4;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_F16:
|
|
||||||
{
|
|
||||||
GGML_ASSERT(ne02 == ne12);
|
|
||||||
|
|
||||||
nth0 = 32;
|
|
||||||
nth1 = 1;
|
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
|
||||||
} break;
|
|
||||||
default: GGML_ASSERT(false && "not implemented");
|
default: GGML_ASSERT(false && "not implemented");
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -551,6 +553,7 @@ void ggml_metal_graph_compute(
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
|
||||||
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
|
||||||
default: GGML_ASSERT(false && "not implemented");
|
default: GGML_ASSERT(false && "not implemented");
|
||||||
}
|
}
|
||||||
|
|
|
@ -169,6 +169,22 @@ kernel void kernel_diag_mask_inf(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_get_rows_f16(
|
||||||
|
device const void * src0,
|
||||||
|
device const int * src1,
|
||||||
|
device float * dst,
|
||||||
|
constant int64_t & ne00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb1,
|
||||||
|
uint tpig[[thread_position_in_grid]]) {
|
||||||
|
const int i = tpig;
|
||||||
|
const int r = ((device int32_t *) src1)[i];
|
||||||
|
|
||||||
|
for (int j = 0; j < ne00; j++) {
|
||||||
|
dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_get_rows_q4_0(
|
kernel void kernel_get_rows_q4_0(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const int * src1,
|
device const int * src1,
|
||||||
|
|
|
@ -961,7 +961,6 @@ static void llama_model_load_internal(
|
||||||
model.hparams = ml->file_loaders.at(0)->hparams;
|
model.hparams = ml->file_loaders.at(0)->hparams;
|
||||||
llama_file_version file_version = ml->file_loaders.at(0)->file_version;
|
llama_file_version file_version = ml->file_loaders.at(0)->file_version;
|
||||||
auto & hparams = model.hparams;
|
auto & hparams = model.hparams;
|
||||||
uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
|
|
||||||
|
|
||||||
{
|
{
|
||||||
switch (hparams.n_layer) {
|
switch (hparams.n_layer) {
|
||||||
|
@ -975,6 +974,8 @@ static void llama_model_load_internal(
|
||||||
hparams.n_ctx = n_ctx;
|
hparams.n_ctx = n_ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
|
||||||
|
|
||||||
{
|
{
|
||||||
fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version));
|
fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version));
|
||||||
fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
|
fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
|
||||||
|
|
Loading…
Reference in a new issue