Custom RoPE + bettter memory management for CUDA (#2295)

* Custom RoPE + bettter memory management for CUDA

* Adjusted look ahead in ggml_cuda_pool_malloc to 5%

This is sufficient it seems.
We end up using about 200 MB less VRAM that way when running
the 13B model with context 8192.

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow 2023-07-21 17:27:51 +03:00 committed by GitHub
parent 4d76a5f49b
commit d924522a46
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -2423,10 +2423,26 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock); scoped_spin_lock lock(g_cuda_pool_lock);
int id; int id;
CUDA_CHECK(cudaGetDevice(&id)); CUDA_CHECK(cudaGetDevice(&id));
#ifdef DEBUG_CUDA_MALLOC
int nnz = 0;
size_t max_size = 0, tot_size = 0;
#endif
size_t best_diff = 1ull << 36;
int ibest = -1;
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
cuda_buffer& b = g_cuda_buffer_pool[id][i]; cuda_buffer& b = g_cuda_buffer_pool[id][i];
if (b.size >= size && b.ptr != nullptr) { if (b.ptr != nullptr) {
#ifdef DEBUG_CUDA_MALLOC
++nnz;
tot_size += b.size;
if (b.size > max_size) max_size = b.size;
#endif
if (b.size >= size) {
size_t diff = b.size - size;
if (diff < best_diff) {
best_diff = diff;
ibest = i;
if (!best_diff) {
void * ptr = b.ptr; void * ptr = b.ptr;
*actual_size = b.size; *actual_size = b.size;
b.ptr = nullptr; b.ptr = nullptr;
@ -2434,9 +2450,26 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
return ptr; return ptr;
} }
} }
}
}
}
if (ibest >= 0) {
cuda_buffer& b = g_cuda_buffer_pool[id][ibest];
void * ptr = b.ptr;
*actual_size = b.size;
b.ptr = nullptr;
b.size = 0;
return ptr;
}
#ifdef DEBUG_CUDA_MALLOC
fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
(uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
#endif
void * ptr; void * ptr;
CUDA_CHECK(cudaMalloc((void **) &ptr, size)); size_t look_ahead_size = (size_t) (1.05 * size);
*actual_size = size; look_ahead_size = 256 * ((look_ahead_size + 255)/256);
CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
*actual_size = look_ahead_size;
return ptr; return ptr;
} }
@ -2955,8 +2988,13 @@ inline void ggml_cuda_op_rope(
const int mode = ((int32_t *) src1->data)[2]; const int mode = ((int32_t *) src1->data)[2];
const int n_ctx = ((int32_t *) src1->data)[3]; const int n_ctx = ((int32_t *) src1->data)[3];
const float theta_scale = powf(10000.0, -2.0f/n_dims); // RoPE alteration for extended context
const float p = ((mode & 1) == 0 ? n_past + i02 : i02); float freq_base, freq_scale;
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
bool is_glm = mode & 4; bool is_glm = mode & 4;