author     Bach Le <bach@bullno1.com>      2023-07-15 03:00:58 +0800
committer  GitHub <noreply@github.com>     2023-07-14 22:00:58 +0300
commit     7cdd30bf1f84339c55a5e3de29384f6bbdebb61c (patch)
tree       bb1c862aebba80b175f213246da5ea771d62da17 /ggml-cuda.cu
parent     e8035f141e1f71d739fa5cfc9c01531cdee6fc16 (diff)
cuda : allocate all temporary ggml_tensor_extra_gpu from a fixed-size buffer (#2220)
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--  ggml-cuda.cu | 24 ++++++++++++++++++++++--
1 file changed, 22 insertions(+), 2 deletions(-)
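What the patch does: on the paths that create short-lived extras (views, copies, scratch allocations), ggml_tensor_extra_gpu objects were previously new-ed one at a time with no matching delete on that path. The new helper instead hands them out round-robin from a single lazily allocated array of GGML_MAX_NODES entries, so temporaries no longer accumulate. Below is a minimal, self-contained sketch of the same pool pattern in plain C++; extra_t, MAX_NODES, and alloc_temp_extra are illustrative stand-ins for ggml_tensor_extra_gpu, GGML_MAX_NODES, and ggml_cuda_alloc_temp_tensor_extra, not names from the patch.

// Minimal sketch of a fixed-size, round-robin pool for temporary extras.
// Names here are placeholders, not identifiers from the patch.
#include <cstddef>
#include <cstring>

struct extra_t { void * data_device[8]; };  // placeholder payload
constexpr std::size_t MAX_NODES = 4096;

static extra_t *   g_pool       = nullptr;  // allocated once, never freed
static std::size_t g_pool_index = 0;        // next slot to hand out

static extra_t * alloc_temp_extra() {
    if (g_pool == nullptr) {
        g_pool = new extra_t[MAX_NODES];    // lazy one-time allocation
    }
    extra_t * extra = &g_pool[g_pool_index];
    g_pool_index = (g_pool_index + 1) % MAX_NODES;  // wrap: slots are recycled
                                                    // after MAX_NODES calls
    std::memset(extra, 0, sizeof(*extra)); // callers always get a zeroed slot
    return extra;
}

int main() {
    extra_t * a = alloc_temp_extra();
    extra_t * b = alloc_temp_extra();
    return (a != b) ? 0 : 1;  // consecutive calls yield distinct slots
}

The trade-off: because the index wraps, a pointer returned by the allocator is only valid until MAX_NODES further allocations have occurred, which is acceptable for extras that live no longer than one graph evaluation.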
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 73cfc55..0646fa7 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -3646,6 +3646,22 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     delete extra;
 }
 
+static struct ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr;
+static size_t g_temp_tensor_extra_index = 0;
+
+static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
+    if (g_temp_tensor_extras == nullptr) {
+        g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES];
+    }
+
+    size_t alloc_index = g_temp_tensor_extra_index;
+    g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_MAX_NODES;
+    struct ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
+    memset(extra, 0, sizeof(*extra));
+
+    return extra;
+}
+
 void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
     if (scratch && g_scratch_size == 0) {
         return;
@@ -3663,8 +3679,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
     }
 
     tensor->backend = GGML_BACKEND_GPU;
-    struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
-    memset(extra, 0, sizeof(*extra));
+    struct ggml_tensor_extra_gpu * extra;
 
     const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
         tensor->op == GGML_OP_VIEW ||
@@ -3679,10 +3694,12 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
         if (tensor->op == GGML_OP_VIEW) {
             memcpy(&offset, tensor->src[2]->data, sizeof(size_t));
         }
+        extra = ggml_cuda_alloc_temp_tensor_extra();
         extra->data_device[g_main_device] = src0_ddc + offset;
     } else if (tensor->op == GGML_OP_CPY) {
         struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra;
         void * src1_ddv = src1_extra->data_device[g_main_device];
+        extra = ggml_cuda_alloc_temp_tensor_extra();
         extra->data_device[g_main_device] = src1_ddv;
     } else if (scratch) {
         GGML_ASSERT(size <= g_scratch_size);
@@ -3695,6 +3712,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
             CUDA_CHECK(cudaMalloc(&data, g_scratch_size));
             g_scratch_buffer = data;
         }
+        extra = ggml_cuda_alloc_temp_tensor_extra();
         extra->data_device[g_main_device] = data + g_scratch_offset;
 
         g_scratch_offset += size;
@@ -3704,6 +3722,8 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
         void * data;
         CUDA_CHECK(cudaMalloc(&data, size));
         CUDA_CHECK(cudaMemset(data, 0, size));
+        extra = new ggml_tensor_extra_gpu;
+        memset(extra, 0, sizeof(*extra));
         extra->data_device[g_main_device] = data;
     }
 
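A note on the one branch that did not change: the final else path (a dedicated cudaMalloc with no scratch buffer) still allocates its extra with new, presumably because such extras outlive a single graph evaluation and are released individually by the delete extra; in ggml_cuda_free_data at the top of the diff. Pool entries, by contrast, are recycled once the index wraps past GGML_MAX_NODES, so they must never be retained across evaluations.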