aboutsummaryrefslogtreecommitdiff
path: root/CMakeLists.txt
diff options
context:
space:
mode:
Diffstat (limited to 'CMakeLists.txt')
-rw-r--r--CMakeLists.txt21
1 files changed, 21 insertions, 0 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ed9a3aa..8eadea4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -66,6 +66,7 @@ endif()
# 3rd party libs
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
+option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
@@ -142,6 +143,26 @@ if (LLAMA_OPENBLAS)
endif()
endif()
+if (LLAMA_CUBLAS)
+ cmake_minimum_required(VERSION 3.17)
+
+ find_package(CUDAToolkit)
+ if (CUDAToolkit_FOUND)
+ message(STATUS "cuBLAS found")
+
+ add_compile_definitions(GGML_USE_CUBLAS)
+
+ if (LLAMA_STATIC)
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+ else()
+ set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
+ endif()
+
+ else()
+ message(WARNING "cuBLAS not found")
+ endif()
+endif()
+
if (LLAMA_ALL_WARNINGS)
if (NOT MSVC)
set(c_flags