diff options
author | Georgi Gerganov <ggerganov@gmail.com> | 2023-07-24 14:46:21 +0300 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-24 14:46:21 +0300 |
commit | 5b2b2dc6ae8086bff7c9b3c17fb435cf319b7185 (patch) | |
tree | db5cf0288472b4fc0ef88217bfcbddf4d18c2a03 /ggml-cuda.cu | |
parent | 42f70cb2f6a8089e0a0560a459e4ba317bac4d49 (diff) |
ggml : sync (unary ops refactor, static-correctness) (#2370)
* ggml : sync (unary ops, tests)
ggml-ci
* tests : remove unnecessary funcs
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r-- | ggml-cuda.cu | 29 |
1 files changed, 17 insertions, 12 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6823adc..b8c9835 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -3962,18 +3962,23 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ } func = ggml_cuda_mul; break; - case GGML_OP_GELU: - if (!any_on_device) { - return false; - } - func = ggml_cuda_gelu; - break; - case GGML_OP_SILU: - if (!any_on_device) { - return false; - } - func = ggml_cuda_silu; - break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_GELU: + if (!any_on_device) { + return false; + } + func = ggml_cuda_gelu; + break; + case GGML_UNARY_OP_SILU: + if (!any_on_device) { + return false; + } + func = ggml_cuda_silu; + break; + default: + return false; + } break; case GGML_OP_NORM: if (!any_on_device) { return false; |