diff options
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r-- | ggml-cuda.cu | 24 |
1 files changed, 14 insertions, 10 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index bcdff36..f11fbe5 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -3898,10 +3898,9 @@ static size_t g_scratch_offset = 0; static int g_device_count = -1; static int g_main_device = 0; -#ifndef GGML_CUDA_FORCE_DMMV static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES]; -#endif static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; +static bool g_mul_mat_q = false; static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; @@ -3923,9 +3922,7 @@ void ggml_init_cublas() { g_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; -#ifndef GGML_CUDA_FORCE_DMMV g_compute_capabilities[id] = 100*prop.major + 10*prop.minor; -#endif } for (int id = 0; id < g_device_count; ++id) { g_tensor_split[id] /= total_vram; @@ -4278,6 +4275,7 @@ inline void ggml_cuda_op_mul_mat_vec( #ifdef GGML_CUDA_FORCE_DMMV const bool use_mul_mat_vec_q = false; + (void) g_compute_capabilities[0]; #else int id; CUDA_CHECK(cudaGetDevice(&id)); @@ -5021,12 +5019,14 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false); } else { -#ifdef GGML_CUDA_CUBLAS - const bool use_mul_mat_q = false; -#else - const bool use_mul_mat_q = ggml_is_quantized(src0->type); -#endif // GGML_CUDA_CUBLAS - if (use_mul_mat_q) { + int min_compute_capability = INT_MAX; + for (int id = 0; id < g_device_count; ++id) { + if (min_compute_capability > g_compute_capabilities[id]) { + min_compute_capability = g_compute_capabilities[id]; + } + } + + if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false); } else { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false); @@ -5320,6 +5320,10 @@ void ggml_cuda_set_main_device(int main_device) { } } +void ggml_cuda_set_mul_mat_q(bool mul_mat_q) { + g_mul_mat_q = mul_mat_q; +} + void ggml_cuda_set_scratch_size(size_t scratch_size) { g_scratch_size = scratch_size; } |