aboutsummaryrefslogtreecommitdiff
path: root/ggml-cuda.cu
diff options
context:
space:
mode:
authorGeorgi Gerganov <ggerganov@gmail.com>2023-07-24 14:46:21 +0300
committerGitHub <noreply@github.com>2023-07-24 14:46:21 +0300
commit5b2b2dc6ae8086bff7c9b3c17fb435cf319b7185 (patch)
treedb5cf0288472b4fc0ef88217bfcbddf4d18c2a03 /ggml-cuda.cu
parent42f70cb2f6a8089e0a0560a459e4ba317bac4d49 (diff)
ggml : sync (unary ops refactor, static-correctness) (#2370)
* ggml : sync (unary ops, tests) ggml-ci * tests : remove unnecessary funcs
Diffstat (limited to 'ggml-cuda.cu')
-rw-r--r--ggml-cuda.cu29
1 files changed, 17 insertions, 12 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 6823adc..b8c9835 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -3962,18 +3962,23 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
}
func = ggml_cuda_mul;
break;
- case GGML_OP_GELU:
- if (!any_on_device) {
- return false;
- }
- func = ggml_cuda_gelu;
- break;
- case GGML_OP_SILU:
- if (!any_on_device) {
- return false;
- }
- func = ggml_cuda_silu;
- break;
+ case GGML_OP_UNARY:
+ switch (ggml_get_unary_op(tensor)) {
+ case GGML_UNARY_OP_GELU:
+ if (!any_on_device) {
+ return false;
+ }
+ func = ggml_cuda_gelu;
+ break;
+ case GGML_UNARY_OP_SILU:
+ if (!any_on_device) {
+ return false;
+ }
+ func = ggml_cuda_silu;
+ break;
+ default:
+ return false;
+ } break;
case GGML_OP_NORM:
if (!any_on_device) {
return false;