metal : add checks for buffer size (#1706)

Co-authored-by: Spencer Sutton <Spencer.Sutton@precisely.com>
This commit is contained in:
Spencer Sutton 2023-06-05 23:28:17 -04:00 committed by GitHub
parent f4c55d3bd7
commit 590250f7a9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 7 deletions

View file

@ -204,6 +204,11 @@ bool ggml_metal_add_buffer(
ctx->buffers[ctx->n_buffers].name = name;
ctx->buffers[ctx->n_buffers].data = data;
ctx->buffers[ctx->n_buffers].size = size;
if (ctx->device.maxBufferLength < aligned_size) {
fprintf(stderr, "%s: buffer '%s' size %zu is larger than buffer maximum of %zu\n", __func__, name, aligned_size, ctx->device.maxBufferLength);
return false;
}
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:aligned_size options:MTLResourceStorageModeShared deallocator:nil];
if (ctx->buffers[ctx->n_buffers].metal == nil) {

View file

@ -2405,17 +2405,30 @@ struct llama_context * llama_init_from_file(
// this allocates all Metal resources and memory buffers
ctx->ctx_metal = ggml_metal_init();
void *data_ptr = NULL;
size_t data_size = 0;
if (params.use_mmap) {
ggml_metal_add_buffer(ctx->ctx_metal, "data", ctx->model.mapping->addr, ctx->model.mapping->size);
ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size);
data_ptr = ctx->model.mapping->addr;
data_size= ctx->model.mapping->size;
} else {
ggml_metal_add_buffer(ctx->ctx_metal, "data", ggml_get_mem_buffer(ctx->model.ctx), ggml_get_mem_size(ctx->model.ctx));
ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size);
data_ptr = ggml_get_mem_buffer(ctx->model.ctx);
data_size= ggml_get_mem_size(ctx->model.ctx);
}
ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->model.kv_self.buf.addr, ctx->model.kv_self.buf.size);
ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size);
ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size);
#define LLAMA_METAL_CHECK_BUF(result) \
if (!(result)) { \
fprintf(stderr, "%s: failed to add buffer\n", __func__); \
llama_free(ctx); \
return NULL; \
}
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->model.kv_self.buf.addr, ctx->model.kv_self.buf.size));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr0", ctx->buf_scratch[0].addr, ctx->buf_scratch[0].size));
LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "scr1", ctx->buf_scratch[1].addr, ctx->buf_scratch[1].size));
#undef LLAMA_METAL_CHECK_BUF
}
#endif