mirror of
https://git.adityakumar.xyz/llama.cpp.git
synced 2024-11-09 15:29:43 +00:00
Improve cuBLAS performance by using a memory pool (#1094)
* Improve cuBLAS performance by using a memory pool * Move cuda specific definitions to ggml-cuda.h/cu * Add CXX flags to nvcc * Change memory pool synchronization mechanism to a spin lock General code cleanup
This commit is contained in:
parent
25d7abbd1f
commit
50cb666b8a
4 changed files with 170 additions and 109 deletions
10
Makefile
10
Makefile
|
@ -101,11 +101,13 @@ ifdef LLAMA_OPENBLAS
|
|||
LDFLAGS += -lopenblas
|
||||
endif
|
||||
ifdef LLAMA_CUBLAS
|
||||
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
|
||||
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64
|
||||
OBJS += ggml-cuda.o
|
||||
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include
|
||||
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64
|
||||
OBJS += ggml-cuda.o
|
||||
NVCC = nvcc
|
||||
NVCCFLAGS = --forward-unknown-to-host-linker -arch=native
|
||||
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
|
||||
nvcc -arch=native -c -o $@ $<
|
||||
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -c $< -o $@
|
||||
endif
|
||||
ifdef LLAMA_GPROF
|
||||
CFLAGS += -pg
|
||||
|
|
116
ggml-cuda.cu
116
ggml-cuda.cu
|
@ -1,5 +1,7 @@
|
|||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <atomic>
|
||||
#include "ggml-cuda.h"
|
||||
|
||||
typedef uint16_t ggml_fp16_t;
|
||||
|
@ -29,14 +31,12 @@ static_assert(sizeof(block_q4_2) == sizeof(ggml_fp16_t) + QK4_2 / 2, "wrong q4_2
|
|||
|
||||
#define QK4_3 16
|
||||
typedef struct {
|
||||
__half d; // delta
|
||||
__half m; // min
|
||||
uint8_t qs[QK4_3 / 2]; // nibbles / quants
|
||||
__half d; // delta
|
||||
__half m; // min
|
||||
uint8_t qs[QK4_3 / 2]; // nibbles / quants
|
||||
} block_q4_3;
|
||||
static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong q4_3 block size/padding");
|
||||
|
||||
|
||||
|
||||
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
|
||||
|
@ -131,24 +131,98 @@ static __global__ void dequantize_block_q4_3(const void * vx, float * y) {
|
|||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
__host__ void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_0;
|
||||
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_0;
|
||||
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
__host__ void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_1;
|
||||
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_1;
|
||||
dequantize_block_q4_1<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
__host__ void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_2;
|
||||
dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_2;
|
||||
dequantize_block_q4_2<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
__host__ void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_3;
|
||||
dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
|
||||
void dequantize_row_q4_3_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_3;
|
||||
dequantize_block_q4_3<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
// buffer pool for cuda
|
||||
#define MAX_CUDA_BUFFERS 16
|
||||
|
||||
struct scoped_spin_lock {
|
||||
std::atomic_flag& lock;
|
||||
scoped_spin_lock(std::atomic_flag& lock) : lock(lock) {
|
||||
while (lock.test_and_set(std::memory_order_acquire)) {
|
||||
; // spin
|
||||
}
|
||||
}
|
||||
~scoped_spin_lock() {
|
||||
lock.clear(std::memory_order_release);
|
||||
}
|
||||
scoped_spin_lock(const scoped_spin_lock&) = delete;
|
||||
scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
|
||||
};
|
||||
|
||||
struct cuda_buffer {
|
||||
void * ptr = nullptr;
|
||||
size_t size = 0;
|
||||
};
|
||||
|
||||
static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
|
||||
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
|
||||
|
||||
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
|
||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||
|
||||
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
||||
cuda_buffer& b = g_cuda_buffer_pool[i];
|
||||
if (b.size >= size && b.ptr != nullptr) {
|
||||
void * ptr = b.ptr;
|
||||
*actual_size = b.size;
|
||||
b.ptr = nullptr;
|
||||
b.size = 0;
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
void * ptr;
|
||||
CUDA_CHECK(cudaMalloc((void **) &ptr, size));
|
||||
*actual_size = size;
|
||||
return ptr;
|
||||
}
|
||||
|
||||
void ggml_cuda_pool_free(void * ptr, size_t size) {
|
||||
scoped_spin_lock lock(g_cuda_pool_lock);
|
||||
|
||||
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
|
||||
cuda_buffer& b = g_cuda_buffer_pool[i];
|
||||
if (b.ptr == nullptr) {
|
||||
b.ptr = ptr;
|
||||
b.size = size;
|
||||
return;
|
||||
}
|
||||
}
|
||||
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
|
||||
CUDA_CHECK(cudaFree(ptr));
|
||||
}
|
||||
|
||||
cublasHandle_t g_cublasH = NULL;
|
||||
cudaStream_t g_cudaStream = NULL;
|
||||
|
||||
void ggml_init_cublas(void) {
|
||||
if (g_cublasH == NULL) {
|
||||
// create cublas handle, bind a stream
|
||||
CUBLAS_CHECK(cublasCreate(&g_cublasH));
|
||||
|
||||
CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStream, cudaStreamNonBlocking));
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(g_cublasH, g_cudaStream));
|
||||
|
||||
// configure logging to stdout
|
||||
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
|
||||
}
|
||||
}
|
||||
|
|
29
ggml-cuda.h
29
ggml-cuda.h
|
@ -1,7 +1,36 @@
|
|||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define CUDA_CHECK(err) \
|
||||
do { \
|
||||
cudaError_t err_ = (err); \
|
||||
if (err_ != cudaSuccess) { \
|
||||
fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
||||
cudaGetErrorString(err_)); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define CUBLAS_CHECK(err) \
|
||||
do { \
|
||||
cublasStatus_t err_ = (err); \
|
||||
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
||||
fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
extern cublasHandle_t g_cublasH;
|
||||
extern cudaStream_t g_cudaStream;
|
||||
|
||||
void ggml_init_cublas(void);
|
||||
void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size);
|
||||
void ggml_cuda_pool_free(void * ptr, size_t size);
|
||||
|
||||
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream);
|
||||
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream);
|
||||
void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream);
|
||||
|
|
124
ggml.c
124
ggml.c
|
@ -148,44 +148,7 @@ inline static void* ggml_aligned_malloc(size_t size) {
|
|||
#elif defined(GGML_USE_OPENBLAS)
|
||||
#include <cblas.h>
|
||||
#elif defined(GGML_USE_CUBLAS)
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include "ggml-cuda.h"
|
||||
|
||||
#define CUDA_CHECK(err) \
|
||||
do { \
|
||||
cudaError_t err_ = (err); \
|
||||
if (err_ != cudaSuccess) { \
|
||||
printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
|
||||
cudaGetErrorString(err_)); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define CUBLAS_CHECK(err) \
|
||||
do { \
|
||||
cublasStatus_t err_ = (err); \
|
||||
if (err_ != CUBLAS_STATUS_SUCCESS) { \
|
||||
printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
|
||||
exit(1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
static cublasHandle_t cublasH = NULL;
|
||||
static cudaStream_t cudaStream = NULL;
|
||||
static void init_cublas(void) {
|
||||
if (cublasH == NULL) {
|
||||
// create cublas handle, bind a stream
|
||||
CUBLAS_CHECK(cublasCreate(&cublasH));
|
||||
|
||||
CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
|
||||
|
||||
CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
|
||||
|
||||
// configure logging to stdout
|
||||
// CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#undef MIN
|
||||
|
@ -3748,7 +3711,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
|||
|
||||
// initialize cuBLAS
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
init_cublas();
|
||||
ggml_init_cublas();
|
||||
#endif
|
||||
|
||||
is_first_call = false;
|
||||
|
@ -7594,18 +7557,16 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||
}
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
float *d_X = NULL;
|
||||
float *d_Y = NULL;
|
||||
float *d_D = NULL;
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
const int x_ne = ne01 * ne10;
|
||||
const int y_ne = ne11 * ne10;
|
||||
const int d_ne = ne11 * ne01;
|
||||
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
||||
size_t x_size, y_size, d_size;
|
||||
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
||||
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
||||
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
||||
#endif
|
||||
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
|
@ -7617,19 +7578,19 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
// copy data to device
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
||||
|
||||
// compute
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha, d_X, ne00,
|
||||
d_Y, ne10,
|
||||
&beta, d_D, ne01));
|
||||
|
||||
// copy data to host
|
||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
||||
#else
|
||||
// zT = y * xT
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||
|
@ -7641,10 +7602,10 @@ static void ggml_compute_forward_mul_mat_f32(
|
|||
}
|
||||
}
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||
CUDA_CHECK(cudaFree(d_X));
|
||||
CUDA_CHECK(cudaFree(d_Y));
|
||||
CUDA_CHECK(cudaFree(d_D));
|
||||
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
|
||||
ggml_cuda_pool_free(d_X, x_size);
|
||||
ggml_cuda_pool_free(d_Y, y_size);
|
||||
ggml_cuda_pool_free(d_D, d_size);
|
||||
#endif
|
||||
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
||||
|
||||
|
@ -7794,18 +7755,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||
#if defined(GGML_USE_CUBLAS)
|
||||
ggml_fp16_t * const wdata = params->wdata;
|
||||
|
||||
float *d_X = NULL;
|
||||
float *d_Y = NULL;
|
||||
float *d_D = NULL;
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
const int x_ne = ne01 * ne10;
|
||||
const int y_ne = ne11 * ne10;
|
||||
const int d_ne = ne11 * ne01;
|
||||
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
||||
size_t x_size, y_size, d_size;
|
||||
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
||||
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
||||
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
||||
#else
|
||||
float * const wdata = params->wdata;
|
||||
#endif
|
||||
|
@ -7839,12 +7798,12 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||
|
||||
// copy data to device
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
||||
|
||||
// compute
|
||||
CUBLAS_CHECK(
|
||||
cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha, d_X, CUDA_R_16F, ne00,
|
||||
d_Y, CUDA_R_16F, ne10,
|
||||
|
@ -7853,7 +7812,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||
CUBLAS_GEMM_DEFAULT));
|
||||
|
||||
// copy data to host
|
||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
||||
#else
|
||||
const float * x = wdata;
|
||||
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
|
||||
|
@ -7871,10 +7830,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
|
|||
}
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||
CUDA_CHECK(cudaFree(d_X));
|
||||
CUDA_CHECK(cudaFree(d_Y));
|
||||
CUDA_CHECK(cudaFree(d_D));
|
||||
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
|
||||
ggml_cuda_pool_free(d_X, x_size);
|
||||
ggml_cuda_pool_free(d_Y, y_size);
|
||||
ggml_cuda_pool_free(d_D, d_size);
|
||||
#endif
|
||||
/*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
|
||||
|
||||
|
@ -8042,20 +8001,17 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
}
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
float *d_X = NULL;
|
||||
float *d_Y = NULL;
|
||||
float *d_D = NULL;
|
||||
float *d_Q = NULL;
|
||||
const float alpha = 1.0f;
|
||||
const float beta = 0.0f;
|
||||
const int x_ne = ne01 * ne10;
|
||||
const int y_ne = ne11 * ne10;
|
||||
const int d_ne = ne11 * ne01;
|
||||
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
|
||||
CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
|
||||
size_t x_size, y_size, d_size, q_size;
|
||||
float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
|
||||
float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
|
||||
float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
|
||||
float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
|
||||
|
||||
void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
|
||||
if (type == GGML_TYPE_Q4_0) {
|
||||
|
@ -8085,9 +8041,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
// copy and dequantize on device
|
||||
CUDA_CHECK(
|
||||
cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
|
||||
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
|
||||
GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
|
||||
|
||||
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
|
||||
dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
#else
|
||||
{
|
||||
|
@ -8103,18 +8059,18 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
// copy data to device
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
|
||||
|
||||
// compute
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha, d_X, ne00,
|
||||
d_Y, ne10,
|
||||
&beta, d_D, ne01));
|
||||
|
||||
// copy data to host
|
||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
|
||||
CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
|
||||
#else
|
||||
// zT = y * xT
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
|
||||
|
@ -8127,11 +8083,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
}
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
CUDA_CHECK(cudaStreamSynchronize(cudaStream));
|
||||
CUDA_CHECK(cudaFree(d_X));
|
||||
CUDA_CHECK(cudaFree(d_Y));
|
||||
CUDA_CHECK(cudaFree(d_D));
|
||||
CUDA_CHECK(cudaFree(d_Q));
|
||||
CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
|
||||
ggml_cuda_pool_free(d_X, x_size);
|
||||
ggml_cuda_pool_free(d_Y, y_size);
|
||||
ggml_cuda_pool_free(d_D, d_size);
|
||||
ggml_cuda_pool_free(d_Q, q_size);
|
||||
#endif
|
||||
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
|
||||
|
||||
|
|
Loading…
Reference in a new issue