aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-08-07 10:52:57 +0300
committerGitHub <noreply@github.com>2023-08-07 10:52:57 +0300
commitf6f9896ac3d2ff207e18f87dab85d126ceef5236 (patch)
tree189c7c5688f33267ff5b88fdcab6028c4eeb2881
parent34a14b28ff7f3c98730339bacee035091b2a812a (diff)
metal : fix out-of-bounds access + inc concurrency nodes (#2416)
* metal : fix out-of-bounds access + style changes * metal : increase concurrency nodes to 2*GGML_MAX_NODES
-rw-r--r--ggml-metal.m57
1 files changed, 39 insertions, 18 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index 3f098d3..b47a98e 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -7,6 +7,11 @@
#import <Metal/Metal.h>
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
+#undef MIN
+#undef MAX
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+
#ifdef GGML_METAL_NDEBUG
#define metal_printf(...)
#else
@@ -15,6 +20,8 @@
#define UNUSED(x) (void)(x)
+#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
+
struct ggml_metal_buffer {
const char * name;
@@ -36,7 +43,7 @@ struct ggml_metal_context {
int n_buffers;
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
- int concur_list[GGML_MAX_NODES];
+ int concur_list[GGML_MAX_CONCUR];
int concur_list_len;
// custom kernels
@@ -370,15 +377,15 @@ void ggml_metal_graph_find_concurrency(
struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) {
int search_depth = gf->n_nodes; //we only find concurrency in this range to avoid wasting too much time
- int nodes_unused[GGML_MAX_NODES];
+ int nodes_unused[GGML_MAX_CONCUR];
- for (int i = 0; i < GGML_MAX_NODES; i++) {ctx->concur_list[i] = 0;}
- for (int i = 0; i < gf->n_nodes; i++) {nodes_unused[i] = 1;}
+ for (int i = 0; i < GGML_MAX_CONCUR; i++) { ctx->concur_list[i] = 0; }
+ for (int i = 0; i < gf->n_nodes; i++) { nodes_unused[i] = 1; }
ctx->concur_list_len = 0;
- int n_left = gf->n_nodes;
- int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
- int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
+ int n_left = gf->n_nodes;
+ int n_start = 0; // all nodes before n_start at nodes_unused array have been sorted and store back to ctx->concur_list
+ int level_pos = 0; // at ctx->concur_list, the last layer (level) ends at level_pos
while (n_left > 0) {
// number of nodes at a layer (that can be issued concurrently)
@@ -386,28 +393,40 @@ void ggml_metal_graph_find_concurrency(
for (int i = n_start; i < ((n_start + search_depth > gf->n_nodes) ? gf->n_nodes : n_start + search_depth); i++) {
if (nodes_unused[i]) {
// if the requirements for gf->nodes[i] are satisfied
- int exe_flag=1;
+ int exe_flag = 1;
+
// scan all srcs
for (int src_ind = 0; src_ind < GGML_MAX_SRC; src_ind++) {
struct ggml_tensor * src_cur = gf->nodes[i]->src[src_ind];
if (src_cur) {
// if is leaf nodes it's satisfied.
- if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {continue;}
+ // TODO: ggml_is_leaf()
+ if (src_cur->op == GGML_OP_NONE && src_cur->grad == NULL) {
+ continue;
+ }
// otherwise this src should be the output from previous nodes.
int is_found = 0;
+
// scan 2*search_depth back because we inserted barrier.
- for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
- if (gf->nodes[ctx->concur_list[j]] == src_cur) {is_found = 1; break;}
+ //for (int j = ((level_pos - 2*search_depth) < 0 ? 0 : (level_pos - 2*search_depth)); j < level_pos; j++) {
+ for (int j = MAX(0, level_pos - 2*search_depth); j < level_pos; j++) {
+ if (ctx->concur_list[j] >= 0 && gf->nodes[ctx->concur_list[j]] == src_cur) {
+ is_found = 1;
+ break;
+ }
+ }
+ if (is_found == 0) {
+ exe_flag = 0;
+ break;
}
- if (is_found == 0) {exe_flag = 0; break;}
}
}
if (exe_flag) {
// check if nodes[i]'s data will be overwritten by a node before nodes[i].
// if node[5] and node[3] write to the same memory region, then we can't issue node[5] before node[3]
int64_t data_start = (int64_t) gf->nodes[i]->data;
- int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
+ int64_t length = (int64_t) ggml_nbytes(gf->nodes[i]);
for (int j = n_start; j < i; j++) {
if (nodes_unused[j] && gf->nodes[j]->op != GGML_OP_RESHAPE \
&& gf->nodes[j]->op != GGML_OP_VIEW \
@@ -416,9 +435,9 @@ void ggml_metal_graph_find_concurrency(
if (((int64_t)gf->nodes[j]->data) >= data_start + length || \
((int64_t)gf->nodes[j]->data) + (int64_t) ggml_nbytes(gf->nodes[j]) <= data_start) {
continue;
- } else {
- exe_flag = 0;
}
+
+ exe_flag = 0;
}
}
}
@@ -435,11 +454,13 @@ void ggml_metal_graph_find_concurrency(
ctx->concur_list[level_pos + concurrency] = -1;
ctx->concur_list_len++;
// jump all sorted nodes at nodes_bak
- while (!nodes_unused[n_start]) {n_start++;}
+ while (!nodes_unused[n_start]) {
+ n_start++;
+ }
level_pos += concurrency + 1;
}
- if (ctx->concur_list_len > GGML_MAX_NODES) {
+ if (ctx->concur_list_len > GGML_MAX_CONCUR) {
fprintf(stderr, "%s: too many elements for metal ctx->concur_list!\n", __func__);
}
}
@@ -453,7 +474,7 @@ void ggml_metal_graph_compute(
// else fallback to serial dispatch
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
- const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_NODES;
+ const bool has_concur = ctx->concur_list_len && ctx->concur_list_len <= GGML_MAX_CONCUR;
const int n_nodes = has_concur ? ctx->concur_list_len : gf->n_nodes;
edesc.dispatchType = has_concur ? MTLDispatchTypeConcurrent : MTLDispatchTypeSerial;