aboutsummaryrefslogtreecommitdiff
path: root/ggml.c
diff options
context:
space:
mode:
authorslaren <2141330+slaren@users.noreply.github.com>2023-04-21 21:59:17 +0200
committerGitHub <noreply@github.com>2023-04-21 21:59:17 +0200
commit50cb666b8a2e35a49b08c0f6bc81138c8f6f2ac1 (patch)
tree80370baa4d8b17d2cb44a134bed6b1a088b1cfc1 /ggml.c
parent25d7abbd1f73582b7e0fdc422a936e8541c0780b (diff)
Improve cuBLAS performance by using a memory pool (#1094)
* Improve cuBLAS performance by using a memory pool * Move cuda specific definitions to ggml-cuda.h/cu * Add CXX flags to nvcc * Change memory pool synchronization mechanism to a spin lock General code cleanup
Diffstat (limited to 'ggml.c')
-rw-r--r--ggml.c124
1 files changed, 40 insertions, 84 deletions
diff --git a/ggml.c b/ggml.c
index 6cea937..2ea4e68 100644
--- a/ggml.c
+++ b/ggml.c
@@ -148,44 +148,7 @@ inline static void* ggml_aligned_malloc(size_t size) {
#elif defined(GGML_USE_OPENBLAS)
#include <cblas.h>
#elif defined(GGML_USE_CUBLAS)
-#include <cublas_v2.h>
-#include <cuda_runtime.h>
#include "ggml-cuda.h"
-
-#define CUDA_CHECK(err) \
- do { \
- cudaError_t err_ = (err); \
- if (err_ != cudaSuccess) { \
- printf("CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
- cudaGetErrorString(err_)); \
- exit(1); \
- } \
- } while (0)
-
-#define CUBLAS_CHECK(err) \
- do { \
- cublasStatus_t err_ = (err); \
- if (err_ != CUBLAS_STATUS_SUCCESS) { \
- printf("cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
- exit(1); \
- } \
- } while (0)
-
-static cublasHandle_t cublasH = NULL;
-static cudaStream_t cudaStream = NULL;
-static void init_cublas(void) {
- if (cublasH == NULL) {
- // create cublas handle, bind a stream
- CUBLAS_CHECK(cublasCreate(&cublasH));
-
- CUDA_CHECK(cudaStreamCreateWithFlags(&cudaStream, cudaStreamNonBlocking));
-
- CUBLAS_CHECK(cublasSetStream(cublasH, cudaStream));
-
- // configure logging to stdout
- // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, NULL));
- }
-}
#endif
#undef MIN
@@ -3748,7 +3711,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
// initialize cuBLAS
#if defined(GGML_USE_CUBLAS)
- init_cublas();
+ ggml_init_cublas();
#endif
is_first_call = false;
@@ -7594,18 +7557,16 @@ static void ggml_compute_forward_mul_mat_f32(
}
#if defined(GGML_USE_CUBLAS)
- float *d_X = NULL;
- float *d_Y = NULL;
- float *d_D = NULL;
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne10;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+ size_t x_size, y_size, d_size;
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
#endif
for (int64_t i03 = 0; i03 < ne03; i03++) {
@@ -7617,19 +7578,19 @@ static void ggml_compute_forward_mul_mat_f32(
#if defined(GGML_USE_CUBLAS)
// copy data to device
- CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, cudaStream));
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(float) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
// compute
CUBLAS_CHECK(
- cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+ cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, d_X, ne00,
d_Y, ne10,
&beta, d_D, ne01));
// copy data to host
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
#else
// zT = y * xT
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -7641,10 +7602,10 @@ static void ggml_compute_forward_mul_mat_f32(
}
}
#if defined(GGML_USE_CUBLAS)
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
- CUDA_CHECK(cudaFree(d_X));
- CUDA_CHECK(cudaFree(d_Y));
- CUDA_CHECK(cudaFree(d_D));
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+ ggml_cuda_pool_free(d_X, x_size);
+ ggml_cuda_pool_free(d_Y, y_size);
+ ggml_cuda_pool_free(d_D, d_size);
#endif
//printf("CBLAS F32 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);
@@ -7794,18 +7755,16 @@ static void ggml_compute_forward_mul_mat_f16_f32(
#if defined(GGML_USE_CUBLAS)
ggml_fp16_t * const wdata = params->wdata;
- float *d_X = NULL;
- float *d_Y = NULL;
- float *d_D = NULL;
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne10;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(ggml_fp16_t) * x_ne));
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
+ size_t x_size, y_size, d_size;
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
#else
float * const wdata = params->wdata;
#endif
@@ -7839,12 +7798,12 @@ static void ggml_compute_forward_mul_mat_f16_f32(
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
// copy data to device
- CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, cudaStream));
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+ CUDA_CHECK(cudaMemcpyAsync(d_X, x, sizeof(ggml_fp16_t) * x_ne, cudaMemcpyHostToDevice, g_cudaStream));
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(ggml_fp16_t) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
// compute
CUBLAS_CHECK(
- cublasGemmEx(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+ cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, d_X, CUDA_R_16F, ne00,
d_Y, CUDA_R_16F, ne10,
@@ -7853,7 +7812,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
CUBLAS_GEMM_DEFAULT));
// copy data to host
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
#else
const float * x = wdata;
const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -7871,10 +7830,10 @@ static void ggml_compute_forward_mul_mat_f16_f32(
}
#if defined(GGML_USE_CUBLAS)
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
- CUDA_CHECK(cudaFree(d_X));
- CUDA_CHECK(cudaFree(d_Y));
- CUDA_CHECK(cudaFree(d_D));
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+ ggml_cuda_pool_free(d_X, x_size);
+ ggml_cuda_pool_free(d_Y, y_size);
+ ggml_cuda_pool_free(d_D, d_size);
#endif
/*printf("CBLAS F16 = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);*/
@@ -8042,20 +8001,17 @@ static void ggml_compute_forward_mul_mat_q_f32(
}
#if defined(GGML_USE_CUBLAS)
- float *d_X = NULL;
- float *d_Y = NULL;
- float *d_D = NULL;
- float *d_Q = NULL;
const float alpha = 1.0f;
const float beta = 0.0f;
const int x_ne = ne01 * ne10;
const int y_ne = ne11 * ne10;
const int d_ne = ne11 * ne01;
- CUDA_CHECK(cudaMalloc((void **)(&d_X), sizeof(float) * x_ne));
- CUDA_CHECK(cudaMalloc((void **)(&d_Y), sizeof(float) * y_ne));
- CUDA_CHECK(cudaMalloc((void **)(&d_D), sizeof(float) * d_ne));
- CUDA_CHECK(cudaMalloc((void **)(&d_Q), GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type]));
+ size_t x_size, y_size, d_size, q_size;
+ float *d_X = ggml_cuda_pool_malloc(sizeof(float) * x_ne, &x_size);
+ float *d_Y = ggml_cuda_pool_malloc(sizeof(float) * y_ne, &y_size);
+ float *d_D = ggml_cuda_pool_malloc(sizeof(float) * d_ne, &d_size);
+ float *d_Q = ggml_cuda_pool_malloc(GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], &q_size);
void (*dequantize_row_q_cuda)(const void * x, float * y, int k, cudaStream_t stream) = NULL;
if (type == GGML_TYPE_Q4_0) {
@@ -8085,9 +8041,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
// copy and dequantize on device
CUDA_CHECK(
cudaMemcpyAsync(d_Q, (char *) src0->data + i03*nb03 + i02*nb02,
- GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, cudaStream));
+ GGML_TYPE_SIZE[type] * x_ne / GGML_BLCK_SIZE[type], cudaMemcpyHostToDevice, g_cudaStream));
- dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, cudaStream);
+ dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
CUDA_CHECK(cudaGetLastError());
#else
{
@@ -8103,18 +8059,18 @@ static void ggml_compute_forward_mul_mat_q_f32(
#if defined(GGML_USE_CUBLAS)
// copy data to device
- CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+ CUDA_CHECK(cudaMemcpyAsync(d_Y, y, sizeof(float) * y_ne, cudaMemcpyHostToDevice, g_cudaStream));
// compute
CUBLAS_CHECK(
- cublasSgemm(cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+ cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
&alpha, d_X, ne00,
d_Y, ne10,
&beta, d_D, ne01));
// copy data to host
- CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+ CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
#else
// zT = y * xT
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
@@ -8127,11 +8083,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
}
#if defined(GGML_USE_CUBLAS)
- CUDA_CHECK(cudaStreamSynchronize(cudaStream));
- CUDA_CHECK(cudaFree(d_X));
- CUDA_CHECK(cudaFree(d_Y));
- CUDA_CHECK(cudaFree(d_D));
- CUDA_CHECK(cudaFree(d_Q));
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStream));
+ ggml_cuda_pool_free(d_X, x_size);
+ ggml_cuda_pool_free(d_Y, y_size);
+ ggml_cuda_pool_free(d_D, d_size);
+ ggml_cuda_pool_free(d_Q, q_size);
#endif
//printf("CBLAS = %f ms, %d x %d x %d x %d\n", (ggml_perf_time_us() - t0)/1000.0, ne0, ne1, ne2, ne3);