mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-09 15:29:43 +00:00
Add LoRA support (#820)
This commit is contained in:
parent
efd05648c8
commit
315a95a4d3
10 changed files with 753 additions and 41 deletions
124
convert-lora-to-ggml.py
Normal file
124
convert-lora-to-ggml.py
Normal file
|
@ -0,0 +1,124 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
import struct
|
||||
import sys
|
||||
from typing import Any, Dict, Sequence, TextIO
|
||||
|
||||
import torch
|
||||
|
||||
from convert import DATA_TYPE_TO_FTYPE, NUMPY_TYPE_TO_DATA_TYPE, DataType
|
||||
|
||||
HF_SUBLAYER_TO_GGML = {
|
||||
"self_attn.q_proj": "attention.wq",
|
||||
"self_attn.k_proj": "attention.wk",
|
||||
"self_attn.v_proj": "attention.wv",
|
||||
"self_attn.o_proj": "attention.wo",
|
||||
"mlp.gate_proj": "feed_forward.w1",
|
||||
"mlp.down_proj": "feed_forward.w2",
|
||||
"mlp.up_proj": "feed_forward.w3",
|
||||
"input_layernorm": "attention_norm",
|
||||
"post_attention_layernorm": "ffn_norm",
|
||||
# "norm": "norm",
|
||||
# "embed_tokens": "tok_embeddings",
|
||||
# "lm_head": "output",
|
||||
}
|
||||
|
||||
|
||||
def translate_tensor_name(t: str) -> str:
|
||||
match = re.match(r".*layers\.(\d+)\.(\w+\.\w+)\.lora_(A|B)\.weight", t)
|
||||
if match:
|
||||
nn = match.group(1)
|
||||
sub_layer = match.group(2)
|
||||
lora_type = match.group(3)
|
||||
|
||||
sub_layer_renamed = HF_SUBLAYER_TO_GGML.get(sub_layer)
|
||||
if sub_layer_renamed is None:
|
||||
print(f"Error: unrecognized sub-layer {sub_layer} in tensor {t}")
|
||||
sys.exit(1)
|
||||
|
||||
output_string = (
|
||||
f"layers.{nn}.{HF_SUBLAYER_TO_GGML[sub_layer]}.weight.lora{lora_type}"
|
||||
)
|
||||
return output_string
|
||||
else:
|
||||
print(f"Error: unrecognized tensor {t}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def write_file_header(fout: TextIO, params: Dict[str, Any]) -> None:
|
||||
fout.write(b"ggla"[::-1]) # magic (ggml lora)
|
||||
fout.write(struct.pack("i", 1)) # file version
|
||||
fout.write(struct.pack("ii", params["r"], params["lora_alpha"]))
|
||||
|
||||
|
||||
def write_tensor_header(
|
||||
self, name: str, shape: Sequence[int], data_type: DataType
|
||||
) -> None:
|
||||
sname = name.encode("utf-8")
|
||||
fout.write(
|
||||
struct.pack(
|
||||
"iii",
|
||||
len(shape),
|
||||
len(sname),
|
||||
DATA_TYPE_TO_FTYPE[NUMPY_TYPE_TO_DATA_TYPE[data_type]],
|
||||
)
|
||||
)
|
||||
fout.write(struct.pack("i" * len(shape), *shape[::-1]))
|
||||
fout.write(sname)
|
||||
fout.seek((fout.tell() + 31) & -32)
|
||||
|
||||
|
||||
if len(sys.argv) != 2:
|
||||
print(f"Usage: python {sys.argv[0]} <path>")
|
||||
print(
|
||||
"Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
input_json = os.path.join(sys.argv[1], "adapter_config.json")
|
||||
input_model = os.path.join(sys.argv[1], "adapter_model.bin")
|
||||
output_path = os.path.join(sys.argv[1], "ggml-adapter-model.bin")
|
||||
|
||||
model = torch.load(input_model, map_location="cpu")
|
||||
|
||||
with open(input_json, "r") as f:
|
||||
params = json.load(f)
|
||||
|
||||
if params["peft_type"] != "LORA":
|
||||
print(f"Error: unsupported adapter type {params['peft_type']}, expected LORA")
|
||||
sys.exit(1)
|
||||
|
||||
if params["fan_in_fan_out"] == True:
|
||||
print("Error: param fan_in_fan_out is not supported")
|
||||
sys.exit(1)
|
||||
|
||||
if params["bias"] is not None and params["bias"] != "none":
|
||||
print("Error: param bias is not supported")
|
||||
sys.exit(1)
|
||||
|
||||
# TODO: these seem to be layers that have been trained but without lora.
|
||||
# doesn't seem widely used but eventually should be supported
|
||||
if params["modules_to_save"] is not None and len(params["modules_to_save"]) > 0:
|
||||
print("Error: param modules_to_save is not supported")
|
||||
sys.exit(1)
|
||||
|
||||
with open(output_path, "wb") as fout:
|
||||
fout.truncate()
|
||||
|
||||
write_file_header(fout, params)
|
||||
for k, v in model.items():
|
||||
if k.endswith("lora_A.weight"):
|
||||
if v.dtype != torch.float16 and v.dtype != torch.float32:
|
||||
v = v.float()
|
||||
v = v.T
|
||||
else:
|
||||
v = v.float()
|
||||
|
||||
t = v.numpy()
|
||||
tname = translate_tensor_name(k)
|
||||
print(f"{k} => {tname} {t.shape} {t.dtype} {t.nbytes/1024/1024:.2f}MB")
|
||||
write_tensor_header(fout, tname, t.shape, t.dtype)
|
||||
t.tofile(fout)
|
||||
|
||||
print(f"Converted {input_json} and {input_model} to {output_path}")
|
|
@ -139,6 +139,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||
break;
|
||||
}
|
||||
params.model = argv[i];
|
||||
} else if (arg == "--lora") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.lora_adapter = argv[i];
|
||||
params.use_mmap = false;
|
||||
} else if (arg == "--lora-base") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.lora_base = argv[i];
|
||||
} else if (arg == "-i" || arg == "--interactive") {
|
||||
params.interactive = true;
|
||||
} else if (arg == "--embedding") {
|
||||
|
@ -242,6 +255,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||
}
|
||||
fprintf(stderr, " --mtest compute maximum memory usage\n");
|
||||
fprintf(stderr, " --verbose-prompt print prompt before generation\n");
|
||||
fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
|
||||
fprintf(stderr, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
|
||||
fprintf(stderr, " -m FNAME, --model FNAME\n");
|
||||
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
|
||||
fprintf(stderr, "\n");
|
||||
|
|
|
@ -32,10 +32,11 @@ struct gpt_params {
|
|||
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
|
||||
std::string prompt = "";
|
||||
std::string input_prefix = ""; // string to prefix user inputs with
|
||||
|
||||
|
||||
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
|
||||
|
||||
std::string lora_adapter = ""; // lora adapter path
|
||||
std::string lora_base = ""; // base model path for the lora adapter
|
||||
|
||||
bool memory_f16 = true; // use f16 instead of f32 for memory kv
|
||||
bool random_prompt = false; // do not randomize prompt if none provided
|
||||
bool use_color = false; // use color to distinguish generations and inputs
|
||||
|
|
|
@ -114,6 +114,17 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
if (!params.lora_adapter.empty()) {
|
||||
int err = llama_apply_lora_from_file(ctx,
|
||||
params.lora_adapter.c_str(),
|
||||
params.lora_base.empty() ? NULL : params.lora_base.c_str(),
|
||||
params.n_threads);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// print system information
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
|
|
|
@ -134,6 +134,17 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
if (!params.lora_adapter.empty()) {
|
||||
int err = llama_apply_lora_from_file(ctx,
|
||||
params.lora_adapter.c_str(),
|
||||
params.lora_base.empty() ? NULL : params.lora_base.c_str(),
|
||||
params.n_threads);
|
||||
if (err != 0) {
|
||||
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// print system information
|
||||
{
|
||||
fprintf(stderr, "\n");
|
||||
|
|
327
ggml.c
327
ggml.c
|
@ -1420,6 +1420,34 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
|||
#endif
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
||||
static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
|
||||
|
||||
static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
||||
[GGML_TYPE_Q4_0] = {
|
||||
.dequantize_row_q = dequantize_row_q4_0,
|
||||
.quantize_row_q = quantize_row_q4_0,
|
||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
|
||||
.quantize_row_q_dot = quantize_row_q8_0,
|
||||
.vec_dot_q = ggml_vec_dot_q4_0_q8_0,
|
||||
},
|
||||
[GGML_TYPE_Q4_1] = {
|
||||
.dequantize_row_q = dequantize_row_q4_1,
|
||||
.quantize_row_q = quantize_row_q4_1,
|
||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
|
||||
.quantize_row_q_dot = quantize_row_q4_1,
|
||||
.vec_dot_q = ggml_vec_dot_q4_1,
|
||||
},
|
||||
// TODO: GGML_TYPE_Q8_0
|
||||
};
|
||||
|
||||
// For internal test use
|
||||
quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
|
||||
GGML_ASSERT(i < GGML_TYPE_COUNT);
|
||||
return quantize_fns[i];
|
||||
}
|
||||
|
||||
|
||||
//
|
||||
// simd mappings
|
||||
//
|
||||
|
@ -5588,6 +5616,26 @@ static void ggml_compute_forward_dup_f16(
|
|||
}
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
|
||||
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
||||
size_t id = 0;
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data;
|
||||
size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
|
||||
float * src0_f32 = (float *) params->wdata;
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
for (int i01 = 0; i01 < ne01; i01++) {
|
||||
const ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
// convert to f32 and quantize
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
src0_f32[i00] = GGML_FP16_TO_FP32(src0_ptr[i00]);
|
||||
}
|
||||
quantize_row_q(src0_f32, dst_ptr + id, ne00);
|
||||
id += dst_row_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false); // TODO: implement
|
||||
}
|
||||
|
@ -5780,6 +5828,21 @@ static void ggml_compute_forward_dup_f32(
|
|||
}
|
||||
}
|
||||
}
|
||||
} else if (dst->type == GGML_TYPE_Q4_0 || dst->type == GGML_TYPE_Q4_1) {
|
||||
quantize_row_q_t const quantize_row_q = quantize_fns[dst->type].quantize_row_q;
|
||||
size_t id = 0;
|
||||
uint8_t * dst_ptr = (uint8_t *) dst->data;
|
||||
size_t dst_row_size = nb0 * (ne00 / GGML_BLCK_SIZE[dst->type]);
|
||||
|
||||
for (int i03 = 0; i03 < ne03; i03++) {
|
||||
for (int i02 = 0; i02 < ne02; i02++) {
|
||||
for (int i01 = 0; i01 < ne01; i01++) {
|
||||
const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
quantize_row_q(src0_ptr, dst_ptr + id, ne00);
|
||||
id += dst_row_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
GGML_ASSERT(false); // TODO: implement
|
||||
}
|
||||
|
@ -5968,6 +6031,212 @@ static void ggml_compute_forward_add_f32(
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_add_f16_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int n = ggml_nrows(src0);
|
||||
const int nc = src0->ne[0];
|
||||
|
||||
const size_t nb00 = src0->nb[0];
|
||||
const size_t nb01 = src0->nb[1];
|
||||
|
||||
const size_t nb10 = src1->nb[0];
|
||||
const size_t nb11 = src1->nb[1];
|
||||
|
||||
const size_t nb0 = dst->nb[0];
|
||||
const size_t nb1 = dst->nb[1];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F16);
|
||||
|
||||
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
||||
|
||||
if (nb10 == sizeof(float)) {
|
||||
for (int j = ith; j < n; j += nth) {
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
|
||||
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
|
||||
for (int i = 0; i < nc; i++) {
|
||||
float * src1_ptr = (float *) ((char *) src1->data + j*nb11 + i*nb10);
|
||||
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + *src1_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
// src1 is not contiguous
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_add_f16_f16(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int n = ggml_nrows(src0);
|
||||
const int nc = src0->ne[0];
|
||||
|
||||
const size_t nb00 = src0->nb[0];
|
||||
const size_t nb01 = src0->nb[1];
|
||||
|
||||
const size_t nb10 = src1->nb[0];
|
||||
const size_t nb11 = src1->nb[1];
|
||||
|
||||
const size_t nb0 = dst->nb[0];
|
||||
const size_t nb1 = dst->nb[1];
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(dst->type == GGML_TYPE_F16);
|
||||
|
||||
GGML_ASSERT( nb0 == sizeof(ggml_fp16_t));
|
||||
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
|
||||
|
||||
if (nb10 == sizeof(ggml_fp16_t)) {
|
||||
for (int j = ith; j < n; j += nth) {
|
||||
ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + j*nb1);
|
||||
ggml_fp16_t * src0_ptr = (ggml_fp16_t *) ((char *) src0->data + j*nb01);
|
||||
for (int i = 0; i < nc; i++) {
|
||||
ggml_fp16_t * src1_ptr = (ggml_fp16_t *) ((char *) src1->data + j*nb11 + i*nb10);
|
||||
dst_ptr[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(src0_ptr[i]) + GGML_FP16_TO_FP32(*src1_ptr));
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
// src1 is not contiguous
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_add_q_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
const struct ggml_tensor * src1,
|
||||
struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst));
|
||||
|
||||
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
const int64_t ne03 = src0->ne[3];
|
||||
|
||||
//const int64_t ne10 = src1->ne[0];
|
||||
//const int64_t ne11 = src1->ne[1];
|
||||
const int64_t ne12 = src1->ne[2];
|
||||
const int64_t ne13 = src1->ne[3];
|
||||
|
||||
//const int64_t ne0 = dst->ne[0];
|
||||
//const int64_t ne1 = dst->ne[1];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const int64_t ne3 = dst->ne[3];
|
||||
|
||||
const int nb00 = src0->nb[0];
|
||||
const int nb01 = src0->nb[1];
|
||||
const int nb02 = src0->nb[2];
|
||||
const int nb03 = src0->nb[3];
|
||||
|
||||
const int nb10 = src1->nb[0];
|
||||
const int nb11 = src1->nb[1];
|
||||
const int nb12 = src1->nb[2];
|
||||
const int nb13 = src1->nb[3];
|
||||
|
||||
const int nb0 = dst->nb[0];
|
||||
const int nb1 = dst->nb[1];
|
||||
const int nb2 = dst->nb[2];
|
||||
const int nb3 = dst->nb[3];
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
GGML_ASSERT(ne02 == ne12);
|
||||
GGML_ASSERT(ne03 == ne13);
|
||||
GGML_ASSERT(ne2 == ne12);
|
||||
GGML_ASSERT(ne3 == ne13);
|
||||
|
||||
const enum ggml_type type = src0->type;
|
||||
dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
|
||||
quantize_row_q_t const quantize_row_q = quantize_fns[type].quantize_row_q;
|
||||
|
||||
// we don't support permuted src0 or src1
|
||||
GGML_ASSERT(nb00 == (int) GGML_TYPE_SIZE[type]);
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 <= nb1);
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1);
|
||||
GGML_ASSERT(dst->type == src0->type);
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
||||
|
||||
// total rows in src0
|
||||
const int nr = ne01*ne02*ne03;
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
float * wdata = (float*) params->wdata + ne00 * ith;
|
||||
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// src0 indices
|
||||
const int i03 = ir/(ne02*ne01);
|
||||
const int i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
// src1 and dst are same shape as src0 => same indices
|
||||
const int i13 = i03;
|
||||
const int i12 = i02;
|
||||
const int i11 = i01;
|
||||
|
||||
const int i3 = i03;
|
||||
const int i2 = i02;
|
||||
const int i1 = i01;
|
||||
|
||||
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
|
||||
float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
|
||||
void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb0));
|
||||
|
||||
assert(ne00 % 32 == 0);
|
||||
|
||||
// unquantize row from src0 to temp buffer
|
||||
dequantize_row_q(src0_row, wdata, ne00);
|
||||
// add src1
|
||||
ggml_vec_acc_f32(ne00, wdata, src1_row);
|
||||
// quantize row to dst
|
||||
quantize_row_q(wdata, dst_row, ne00);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_add(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
|
@ -5978,6 +6247,23 @@ static void ggml_compute_forward_add(
|
|||
{
|
||||
ggml_compute_forward_add_f32(params, src0, src1, dst);
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
if (src1->type == GGML_TYPE_F16) {
|
||||
ggml_compute_forward_add_f16_f16(params, src0, src1, dst);
|
||||
}
|
||||
else if (src1->type == GGML_TYPE_F32) {
|
||||
ggml_compute_forward_add_f16_f32(params, src0, src1, dst);
|
||||
}
|
||||
else {
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
{
|
||||
ggml_compute_forward_add_q_f32(params, src0, src1, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ASSERT(false);
|
||||
|
@ -7257,30 +7543,6 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||
//}
|
||||
}
|
||||
|
||||
static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
||||
[GGML_TYPE_Q4_0] = {
|
||||
.dequantize_row_q = dequantize_row_q4_0,
|
||||
.quantize_row_q = quantize_row_q4_0,
|
||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
|
||||
.quantize_row_q_dot = quantize_row_q8_0,
|
||||
.vec_dot_q = ggml_vec_dot_q4_0_q8_0,
|
||||
},
|
||||
[GGML_TYPE_Q4_1] = {
|
||||
.dequantize_row_q = dequantize_row_q4_1,
|
||||
.quantize_row_q = quantize_row_q4_1,
|
||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
|
||||
.quantize_row_q_dot = quantize_row_q4_1,
|
||||
.vec_dot_q = ggml_vec_dot_q4_1,
|
||||
},
|
||||
// TODO: GGML_TYPE_Q8_0
|
||||
};
|
||||
|
||||
// For internal test use
|
||||
quantize_fns_t ggml_internal_get_quantize_fn(size_t i) {
|
||||
GGML_ASSERT(i < GGML_TYPE_COUNT);
|
||||
return quantize_fns[i];
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_mul_mat_q_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
const struct ggml_tensor * src0,
|
||||
|
@ -10137,13 +10399,29 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||
struct ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
switch (node->op) {
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_DUP:
|
||||
{
|
||||
node->n_tasks = 1;
|
||||
|
||||
size_t cur = 0;
|
||||
if (node->type == GGML_TYPE_Q4_0 || node->type == GGML_TYPE_Q4_1) {
|
||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0];
|
||||
}
|
||||
|
||||
work_size = MAX(work_size, cur);
|
||||
} break;
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
node->n_tasks = n_threads;
|
||||
|
||||
size_t cur = 0;
|
||||
|
||||
if (node->src0->type == GGML_TYPE_Q4_0 || node->src0->type == GGML_TYPE_Q4_1) {
|
||||
cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads;
|
||||
}
|
||||
|
||||
work_size = MAX(work_size, cur);
|
||||
} break;
|
||||
case GGML_OP_SUB:
|
||||
case GGML_OP_MUL:
|
||||
|
@ -10224,7 +10502,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||
{
|
||||
node->n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
case GGML_OP_RESHAPE:
|
||||
case GGML_OP_VIEW:
|
||||
|
|
6
ggml.h
6
ggml.h
|
@ -430,6 +430,12 @@ struct ggml_tensor * ggml_add(
|
|||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
|
||||
struct ggml_tensor * ggml_add_inplace(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b);
|
||||
|
||||
struct ggml_tensor * ggml_sub(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
|
|
251
llama.cpp
251
llama.cpp
|
@ -1,6 +1,8 @@
|
|||
// Defines fileno on msys:
|
||||
#ifndef _GNU_SOURCE
|
||||
#define _GNU_SOURCE
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#endif
|
||||
|
||||
#include "llama_util.h"
|
||||
|
@ -633,6 +635,7 @@ struct llama_model_loader {
|
|||
throw format("llama.cpp: tensor '%s' has wrong shape; expected %s, got %s",
|
||||
name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str());
|
||||
}
|
||||
|
||||
return get_tensor_for(lt);
|
||||
}
|
||||
|
||||
|
@ -1774,6 +1777,254 @@ int llama_model_quantize(
|
|||
}
|
||||
}
|
||||
|
||||
int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) {
|
||||
fprintf(stderr, "%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
|
||||
|
||||
auto & model = ctx->model;
|
||||
|
||||
const int64_t t_start_lora_us = ggml_time_us();
|
||||
|
||||
auto fin = std::ifstream(path_lora, std::ios::binary);
|
||||
if (!fin) {
|
||||
fprintf(stderr, "%s: failed to open '%s'\n", __func__, path_lora);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// verify magic and version
|
||||
{
|
||||
uint32_t magic;
|
||||
fin.read((char *) &magic, sizeof(magic));
|
||||
if (magic != 'ggla') {
|
||||
fprintf(stderr, "%s: bad file magic\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
uint32_t format_version;
|
||||
fin.read((char *) &format_version, sizeof(format_version));
|
||||
|
||||
if (format_version != 1) {
|
||||
fprintf(stderr, "%s: unsupported file version\n", __func__ );
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
int32_t lora_r;
|
||||
int32_t lora_alpha;
|
||||
fin.read((char *) &lora_r, sizeof(lora_r));
|
||||
fin.read((char *) &lora_alpha, sizeof(lora_alpha));
|
||||
float scaling = (float)lora_alpha / (float)lora_r;
|
||||
|
||||
fprintf(stderr, "%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
|
||||
|
||||
|
||||
// create a temporary ggml context to store the lora tensors
|
||||
// todo: calculate size from biggest possible tensor
|
||||
std::vector<uint8_t> lora_buf(1024ull * 1024ull * 1024ull);
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = lora_buf.size();
|
||||
params.mem_buffer = lora_buf.data();
|
||||
params.no_alloc = false;
|
||||
|
||||
ggml_context * lora_ctx = ggml_init(params);
|
||||
std::unordered_map<std::string, struct ggml_tensor *> lora_tensors;
|
||||
|
||||
// create a name -> tensor map of the model to accelerate lookups
|
||||
std::unordered_map<std::string, struct ggml_tensor*> model_tensors;
|
||||
for (auto & kv: model.tensors_by_name) {
|
||||
model_tensors.insert(kv);
|
||||
}
|
||||
|
||||
|
||||
// load base model
|
||||
std::unique_ptr<llama_model_loader> model_loader;
|
||||
ggml_context * base_ctx = NULL;
|
||||
llama_buffer base_buf;
|
||||
if (path_base_model) {
|
||||
fprintf(stderr, "%s: loading base model from '%s'\n", __func__, path_base_model);
|
||||
model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*vocab_only*/ false));
|
||||
|
||||
size_t ctx_size, mmapped_size;
|
||||
model_loader->calc_sizes(&ctx_size, &mmapped_size);
|
||||
base_buf.resize(ctx_size);
|
||||
|
||||
ggml_init_params base_params;
|
||||
base_params.mem_size = base_buf.size;
|
||||
base_params.mem_buffer = base_buf.addr;
|
||||
base_params.no_alloc = model_loader->use_mmap;
|
||||
|
||||
base_ctx = ggml_init(base_params);
|
||||
|
||||
model_loader->ggml_ctx = base_ctx;
|
||||
|
||||
// maybe this should in llama_model_loader
|
||||
if (model_loader->use_mmap) {
|
||||
model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, /* prefetch */ false));
|
||||
}
|
||||
}
|
||||
|
||||
// read tensors and apply
|
||||
bool warned = false;
|
||||
int n_tensors = 0;
|
||||
while (true) {
|
||||
int32_t n_dims;
|
||||
int32_t length;
|
||||
int32_t ftype;
|
||||
|
||||
fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
|
||||
fin.read(reinterpret_cast<char *>(&length), sizeof(length));
|
||||
fin.read(reinterpret_cast<char *>(&ftype), sizeof(ftype));
|
||||
if (fin.eof()) {
|
||||
break;
|
||||
}
|
||||
|
||||
int32_t ne[2] = { 1, 1 };
|
||||
for (int i = 0; i < n_dims; ++i) {
|
||||
fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
|
||||
}
|
||||
|
||||
std::string name(length, 0);
|
||||
fin.read(&name[0], length);
|
||||
|
||||
// check for lora suffix and get the type of tensor
|
||||
const std::string lora_suffix = ".lora";
|
||||
size_t pos = name.rfind(lora_suffix);
|
||||
if (pos == std::string::npos) {
|
||||
fprintf(stderr, "%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::string lora_type = name.substr(pos + lora_suffix.length());
|
||||
std::string base_name = name;
|
||||
base_name.erase(pos);
|
||||
// fprintf(stderr, "%s: %s => %s (lora type %s) ", __func__, name.c_str(),base_name.c_str(), lora_type.c_str());
|
||||
|
||||
if (model_tensors.find(base_name.data()) == model_tensors.end()) {
|
||||
fprintf(stderr, "%s: unknown tensor '%s' in lora adapter\n", __func__, name.data());
|
||||
return 1;
|
||||
}
|
||||
|
||||
// create ggml tensor
|
||||
ggml_type wtype;
|
||||
switch (ftype) {
|
||||
case 0: wtype = GGML_TYPE_F32; break;
|
||||
case 1: wtype = GGML_TYPE_F16; break;
|
||||
default:
|
||||
{
|
||||
fprintf(stderr, "%s: invalid tensor data type '%d'\n",
|
||||
__func__, ftype);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
ggml_tensor* lora_tensor;
|
||||
if (n_dims == 2) {
|
||||
lora_tensor = ggml_new_tensor_2d(lora_ctx, wtype, ne[0], ne[1]);
|
||||
}
|
||||
else {
|
||||
fprintf(stderr, "%s: unsupported tensor dimension %d\n", __func__, n_dims);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// load tensor data
|
||||
size_t offset = fin.tellg();
|
||||
size_t tensor_data_size = ggml_nbytes(lora_tensor);
|
||||
offset = (offset + 31) & -32;
|
||||
fin.seekg(offset);
|
||||
fin.read((char*)lora_tensor->data, tensor_data_size);
|
||||
|
||||
lora_tensors[name] = lora_tensor;
|
||||
|
||||
// check if we have both A and B tensors and apply
|
||||
if (lora_tensors.find(base_name + ".loraA") != lora_tensors.end() &&
|
||||
lora_tensors.find(base_name + ".loraB") != lora_tensors.end()) {
|
||||
|
||||
ggml_tensor * dest_t = model_tensors[base_name];
|
||||
ggml_tensor * base_t;
|
||||
if (model_loader) {
|
||||
// load from base model
|
||||
if (model_loader->tensors_map.name_to_idx.find(base_name) == model_loader->tensors_map.name_to_idx.end()) {
|
||||
fprintf(stderr, "%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
|
||||
return 1;
|
||||
}
|
||||
size_t idx = model_loader->tensors_map.name_to_idx[base_name];
|
||||
llama_load_tensor & lt = model_loader->tensors_map.tensors[idx];
|
||||
base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] });
|
||||
lt.data = (uint8_t *) lt.ggml_tensor->data;
|
||||
model_loader->load_data_for(lt);
|
||||
lt.ggml_tensor->data = lt.data;
|
||||
}
|
||||
else {
|
||||
base_t = dest_t;
|
||||
}
|
||||
|
||||
if (base_t->type == GGML_TYPE_Q4_0 || base_t->type == GGML_TYPE_Q4_1) {
|
||||
if (!warned) {
|
||||
fprintf(stderr, "%s: warning: using a lora adapter with a quantized model may result in poor quality, "
|
||||
"use a f16 or f32 base model with --lora-base\n", __func__);
|
||||
warned = true;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_tensor * loraA = lora_tensors[base_name + ".loraA"];
|
||||
ggml_tensor * loraB = lora_tensors[base_name + ".loraB"];
|
||||
|
||||
if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
|
||||
fprintf(stderr, "%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
|
||||
" are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// w = w + BA*s
|
||||
ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
|
||||
|
||||
if (scaling != 1.0f) {
|
||||
ggml_tensor * scale_tensor = ggml_new_f32(lora_ctx, scaling);
|
||||
BA = ggml_scale(lora_ctx, BA, scale_tensor);
|
||||
}
|
||||
|
||||
ggml_tensor * r;
|
||||
if (base_t == dest_t) {
|
||||
r = ggml_add_inplace(lora_ctx, dest_t, BA);
|
||||
}
|
||||
else {
|
||||
r = ggml_add(lora_ctx, base_t, BA);
|
||||
r = ggml_cpy(lora_ctx, r, dest_t);
|
||||
}
|
||||
|
||||
struct ggml_cgraph gf = ggml_build_forward(r);
|
||||
gf.n_threads = n_threads;
|
||||
ggml_graph_compute(lora_ctx, &gf);
|
||||
|
||||
// we won't need these tensors again, reset the context to save memory
|
||||
ggml_free(lora_ctx);
|
||||
lora_ctx = ggml_init(params);
|
||||
lora_tensors.clear();
|
||||
|
||||
n_tensors++;
|
||||
if (n_tensors % 4 == 0)
|
||||
fprintf(stderr, ".");
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: this should be in a destructor, it will leak on failure
|
||||
ggml_free(lora_ctx);
|
||||
if (base_ctx) {
|
||||
ggml_free(base_ctx);
|
||||
}
|
||||
|
||||
const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
|
||||
fprintf(stderr, " done (%.2f ms)\n", t_lora_us / 1000.0);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) {
|
||||
try {
|
||||
return llama_apply_lora_from_file_internal(ctx, path_lora, path_base_model, n_threads);
|
||||
} catch (const std::string & err) {
|
||||
fprintf(stderr, "%s: failed to apply lora adapter: %s\n", __func__, err.c_str());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the KV cache that will contain the context for the
|
||||
// ongoing prediction with the model.
|
||||
const uint8_t * llama_get_kv_cache(struct llama_context * ctx) {
|
||||
|
|
12
llama.h
12
llama.h
|
@ -96,6 +96,18 @@ extern "C" {
|
|||
const char * fname_out,
|
||||
enum llama_ftype ftype);
|
||||
|
||||
// Apply a LoRA adapter to a loaded model
|
||||
// path_base_model is the path to a higher quality model to use as a base for
|
||||
// the layers modified by the adapter. Can be NULL to use the current loaded model.
|
||||
// The model needs to be reloaded before applying a new adapter, otherwise the adapter
|
||||
// will be applied on top of the previous one
|
||||
// Returns 0 on success
|
||||
LLAMA_API int llama_apply_lora_from_file(
|
||||
struct llama_context * ctx,
|
||||
const char * path_lora,
|
||||
const char * path_base_model,
|
||||
int n_threads);
|
||||
|
||||
// Returns the KV cache that will contain the context for the
|
||||
// ongoing prediction with the model.
|
||||
LLAMA_API const uint8_t * llama_get_kv_cache(struct llama_context * ctx);
|
||||
|
|
|
@ -168,7 +168,7 @@ struct llama_mmap {
|
|||
#ifdef _POSIX_MAPPED_FILES
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
llama_mmap(struct llama_file * file) {
|
||||
llama_mmap(struct llama_file * file, bool prefetch = true) {
|
||||
size = file->size;
|
||||
int fd = fileno(file->fp);
|
||||
int flags = MAP_SHARED;
|
||||
|
@ -180,12 +180,14 @@ struct llama_mmap {
|
|||
throw format("mmap failed: %s", strerror(errno));
|
||||
}
|
||||
|
||||
if (prefetch) {
|
||||
// Advise the kernel to preload the mapped memory
|
||||
if (madvise(addr, file->size, MADV_WILLNEED)) {
|
||||
fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n",
|
||||
strerror(errno));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
~llama_mmap() {
|
||||
munmap(addr, size);
|
||||
|
@ -193,7 +195,7 @@ struct llama_mmap {
|
|||
#elif defined(_WIN32)
|
||||
static constexpr bool SUPPORTED = true;
|
||||
|
||||
llama_mmap(struct llama_file * file) {
|
||||
llama_mmap(struct llama_file * file, bool prefetch = true) {
|
||||
size = file->size;
|
||||
|
||||
HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp));
|
||||
|
@ -215,6 +217,7 @@ struct llama_mmap {
|
|||
}
|
||||
|
||||
#if _WIN32_WINNT >= _WIN32_WINNT_WIN8
|
||||
if (prefetch) {
|
||||
// Advise the kernel to preload the mapped memory
|
||||
WIN32_MEMORY_RANGE_ENTRY range;
|
||||
range.VirtualAddress = addr;
|
||||
|
@ -223,6 +226,7 @@ struct llama_mmap {
|
|||
fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n",
|
||||
llama_format_win_err(GetLastError()).c_str());
|
||||
}
|
||||
}
|
||||
#else
|
||||
#pragma message("warning: You are building for pre-Windows 8; prefetch not supported")
|
||||
#endif // _WIN32_WINNT >= _WIN32_WINNT_WIN8
|
||||
|
|
Loading…
Reference in a new issue