 ggml-metal.m     | 23
 ggml-metal.metal | 16
 llama.cpp        |  3
 3 files changed, 31 insertions(+), 11 deletions(-)
diff --git a/ggml-metal.m b/ggml-metal.m
index d721ac6..0953af6 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -47,10 +47,11 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(relu);
GGML_METAL_DECL_KERNEL(soft_max);
GGML_METAL_DECL_KERNEL(diag_mask_inf);
+ GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(rms_norm);
- GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(rope);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
@@ -130,10 +131,11 @@ struct ggml_metal_context * ggml_metal_init(void) {
GGML_METAL_ADD_KERNEL(relu);
GGML_METAL_ADD_KERNEL(soft_max);
GGML_METAL_ADD_KERNEL(diag_mask_inf);
+ GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(rms_norm);
- GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(rope);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
@@ -498,6 +500,14 @@ void ggml_metal_graph_compute(
// use custom matrix x vector kernel
switch (src0t) {
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(ne02 == ne12);
+
+ nth0 = 64;
+ nth1 = 1;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+ } break;
case GGML_TYPE_Q4_0:
{
GGML_ASSERT(ne02 == 1);
@@ -507,14 +517,6 @@ void ggml_metal_graph_compute(
nth1 = 4;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
} break;
- case GGML_TYPE_F16:
- {
- GGML_ASSERT(ne02 == ne12);
-
- nth0 = 32;
- nth1 = 1;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
- } break;
default: GGML_ASSERT(false && "not implemented");
};
@@ -551,6 +553,7 @@ void ggml_metal_graph_compute(
}
switch (src0->type) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
default: GGML_ASSERT(false && "not implemented");
}
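
Host-side, this diff moves the F16 case to the top of the mul_mat switch, raises its threadgroup width from 32 to 64 threads (nth0 and nth1 end up as the threadsPerThreadgroup dimensions of the dispatch), and wires the new get_rows_f16 pipeline into the get-rows dispatch. For orientation, here is a plain-C sketch of the f16 x f32 matrix-vector product the custom kernel computes; the helper names and the row-major layout are illustrative assumptions, not the actual ggml-metal implementation:

/*
 * Sketch of the f16 x f32 mat-vec the custom kernel implements:
 * y[r] = sum_k fp32(A[r][k]) * x[k].
 */
#include <math.h>
#include <stdint.h>

static float fp16_to_fp32(uint16_t h) { // IEEE half -> float decode
    const int sign = (h >> 15) ? -1 : 1;
    const int exp  = (h >> 10) & 0x1f;
    const int mant =  h        & 0x3ff;
    if (exp == 0x1f) return mant ? NAN : sign * INFINITY;    // inf/NaN
    if (exp == 0)    return sign * ldexpf((float) mant, -24); // zero/subnormal
    return sign * ldexpf((float)(mant + 1024), exp - 25);     // normal
}

static void mul_mat_f16_f32_ref(
        const uint16_t * A, // ne01 x ne00 f16 weights, row-major
        const float    * x, // ne00-element f32 input vector
        float          * y, // ne01-element f32 output vector
        int64_t ne00, int64_t ne01) {
    // on Metal, each threadgroup handles one output row, with nth0
    // threads cooperating on the reduction over k
    for (int64_t r = 0; r < ne01; ++r) {
        float sum = 0.0f;
        for (int64_t k = 0; k < ne00; ++k) {
            sum += fp16_to_fp32(A[r*ne00 + k]) * x[k];
        }
        y[r] = sum;
    }
}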
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 4bedc8e..a359beb 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -169,6 +169,22 @@ kernel void kernel_diag_mask_inf(
}
}
+kernel void kernel_get_rows_f16(
+ device const void * src0,
+ device const int * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb1,
+ uint tpig[[thread_position_in_grid]]) {
+ const int i = tpig;
+ const int r = ((device int32_t *) src1)[i];
+
+ for (int j = 0; j < ne00; j++) {
+ dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
+ }
+}
+
kernel void kernel_get_rows_q4_0(
device const void * src0,
device const int * src1,
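
The new kernel is a gather: thread i looks up a row index r = src1[i] and widens row r of src0 from f16 to f32 into row i of dst. A plain-C sketch of the same semantics follows; strides are kept in elements for clarity (the Metal kernel indexes src0 with the byte stride nb01), and the helper name is an illustrative assumption:

#include <stdint.h>

float fp16_to_fp32(uint16_t h); // e.g. the decoder from the sketch above

static void get_rows_f16_ref(
        const void    * src0, // f16 matrix, nb01 bytes per row
        const int32_t * src1, // row indices to gather
        float         * dst,  // ne00 f32 values per output row
        int64_t ne00, uint64_t nb01, int n_rows) {
    for (int i = 0; i < n_rows; ++i) { // one Metal thread per i
        const uint16_t * row =
            (const uint16_t *)((const char *) src0 + src1[i]*nb01);
        for (int64_t j = 0; j < ne00; ++j) {
            dst[i*ne00 + j] = fp16_to_fp32(row[j]);
        }
    }
}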
diff --git a/llama.cpp b/llama.cpp
index 70341d0..73f6860 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -961,7 +961,6 @@ static void llama_model_load_internal(
model.hparams = ml->file_loaders.at(0)->hparams;
llama_file_version file_version = ml->file_loaders.at(0)->file_version;
auto & hparams = model.hparams;
- uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
{
switch (hparams.n_layer) {
@@ -975,6 +974,8 @@ static void llama_model_load_internal(
hparams.n_ctx = n_ctx;
}
+ const uint32_t n_ff = ((2*(4*hparams.n_embd)/3 + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
+
{
fprintf(stderr, "%s: format = %s\n", __func__, llama_file_version_name(file_version));
fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);