aboutsummaryrefslogtreecommitdiff
path: root/ggml-metal.m
diff options
context:
space:
mode:
Diffstat (limited to 'ggml-metal.m')
-rw-r--r--ggml-metal.m11
1 files changed, 9 insertions, 2 deletions
diff --git a/ggml-metal.m b/ggml-metal.m
index fd69c41..3f15f79 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -25,6 +25,8 @@ struct ggml_metal_buffer {
};
struct ggml_metal_context {
+ int n_cb;
+
float * logits;
id<MTLDevice> device;
@@ -86,11 +88,12 @@ static NSString * const msl_library_source = @"see metal.metal";
@implementation GGMLMetalClass
@end
-struct ggml_metal_context * ggml_metal_init(void) {
+struct ggml_metal_context * ggml_metal_init(int n_cb) {
fprintf(stderr, "%s: allocating\n", __func__);
struct ggml_metal_context * ctx = malloc(sizeof(struct ggml_metal_context));
+ ctx->n_cb = n_cb;
ctx->device = MTLCreateSystemDefaultDevice();
ctx->queue = [ctx->device newCommandQueue];
ctx->n_buffers = 0;
@@ -208,6 +211,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
free(ctx);
}
+// Sets the number of command buffers used by subsequent graph computations.
+//
+// ctx  - Metal context to update; must be a valid, non-NULL context.
+// n_cb - requested number of command buffers.
+//
+// ggml_metal_graph_compute reads ctx->n_cb and passes it on as the capacity of
+// its command-buffer array, so a zero or negative request is raised to 1 here
+// instead of being stored as-is. Values >= 1 are stored unchanged.
+void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
+ ctx->n_cb = n_cb < 1 ? 1 : n_cb;
+}
+
// finds the Metal buffer that contains the tensor data on the GPU device
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
// Metal buffer based on the host memory pointer
@@ -354,7 +361,7 @@ void ggml_metal_graph_compute(
// create multiple command buffers and enqueue them
// then, we encode the graph into the command buffers in parallel
- const int n_cb = gf->n_threads;
+ const int n_cb = ctx->n_cb;
NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];