diff options
author | Howard Su <howard0su@gmail.com> | 2023-07-13 21:58:09 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-13 21:58:09 +0800 |
commit | ff5d58faecf1f02b05bd015bdfc6a394cf2bc9ba (patch) | |
tree | 8d5b87af7cfa839e9ba1bcd58e434e417bce631d | |
parent | b782422a3e090d0aeab84bfa03ba008dcd1c2a3d (diff) |
Fix compile error on Windows CUDA (#2207)
-rw-r--r-- | ggml-cuda.cu | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/ggml-cuda.cu b/ggml-cuda.cu index dc4b773..e0d5e91 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -267,10 +267,9 @@ static __global__ void mul_f32(const float * x, const float * y, float * dst, co dst[i] = x[i] * y[i%ky]; } -static const float GELU_COEF_A = 0.044715f; -static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; - static __global__ void gelu_f32(const float * x, float * dst, const int k) { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { @@ -2300,7 +2299,7 @@ inline void ggml_cuda_op_add( const int64_t ne00 = src0->ne[0]; const int64_t i01_diff = i01_high - i01_low; - const int64_t ne10 = src1->ne[0]; + // const int64_t ne10 = src1->ne[0]; // compute if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { |