aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ggml-cuda.cu60
1 files changed, 49 insertions, 11 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 6537897..c07b546 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -2423,20 +2423,53 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
-
+#ifdef DEBUG_CUDA_MALLOC
+ int nnz = 0;
+ size_t max_size = 0, tot_size = 0;
+#endif
+ size_t best_diff = 1ull << 36;
+ int ibest = -1;
for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
cuda_buffer& b = g_cuda_buffer_pool[id][i];
- if (b.size >= size && b.ptr != nullptr) {
- void * ptr = b.ptr;
- *actual_size = b.size;
- b.ptr = nullptr;
- b.size = 0;
- return ptr;
+ if (b.ptr != nullptr) {
+#ifdef DEBUG_CUDA_MALLOC
+ ++nnz;
+ tot_size += b.size;
+ if (b.size > max_size) max_size = b.size;
+#endif
+ if (b.size >= size) {
+ size_t diff = b.size - size;
+ if (diff < best_diff) {
+ best_diff = diff;
+ ibest = i;
+ if (!best_diff) {
+ void * ptr = b.ptr;
+ *actual_size = b.size;
+ b.ptr = nullptr;
+ b.size = 0;
+ return ptr;
+ }
+ }
+ }
}
}
+ if (ibest >= 0) {
+ cuda_buffer& b = g_cuda_buffer_pool[id][ibest];
+ void * ptr = b.ptr;
+ *actual_size = b.size;
+ b.ptr = nullptr;
+ b.size = 0;
+ return ptr;
+ }
+#ifdef DEBUG_CUDA_MALLOC
+ fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
+ (uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
+#endif
void * ptr;
- CUDA_CHECK(cudaMalloc((void **) &ptr, size));
- *actual_size = size;
+ size_t look_ahead_size = (size_t) (1.05 * size);
+ look_ahead_size = 256 * ((look_ahead_size + 255)/256);
+ CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
+ *actual_size = look_ahead_size;
return ptr;
}
@@ -2955,8 +2988,13 @@ inline void ggml_cuda_op_rope(
const int mode = ((int32_t *) src1->data)[2];
const int n_ctx = ((int32_t *) src1->data)[3];
- const float theta_scale = powf(10000.0, -2.0f/n_dims);
- const float p = ((mode & 1) == 0 ? n_past + i02 : i02);
+ // RoPE alteration for extended context
+ float freq_base, freq_scale;
+ memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
+ memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
+
+ const float theta_scale = powf(freq_base, -2.0f/n_dims);
+ const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
bool is_glm = mode & 4;