aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohannes Gäßler <johannesg@5d6.de>2023-06-19 10:23:56 +0200
committerGitHub <noreply@github.com>2023-06-19 10:23:56 +0200
commit16b9cd193965769089881bb8ec012fccca7b37b6 (patch)
tree2ee329793e782f253966fd81f89ea05f5a1a2495
parentb24c3049d96557c24782e4d32feaae65f47277af (diff)
Convert vector to f16 for dequantize mul mat vec (#1913)
* Convert vector to f16 for dmmv * compile option * Added compilation option description to README * Changed cmake CUDA_ARCHITECTURES from "OFF" to "native"
-rw-r--r--CMakeLists.txt10
-rw-r--r--Makefile3
-rw-r--r--README.md9
-rw-r--r--ggml-cuda.cu202
-rw-r--r--llama.cpp2
5 files changed, 158 insertions, 68 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7367719..dc06365 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -70,6 +70,7 @@ set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
+option(LLAMA_CUDA_DMMV_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF)
set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
option(LLAMA_METAL "llama: use Metal" OFF)
@@ -238,6 +239,9 @@ if (LLAMA_CUBLAS)
add_compile_definitions(GGML_USE_CUBLAS)
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
+ if (LLAMA_CUDA_DMMV_F16)
+ add_compile_definitions(GGML_CUDA_DMMV_F16)
+ endif()
add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
if (LLAMA_STATIC)
@@ -490,13 +494,13 @@ endif()
if (GGML_SOURCES_CUDA)
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
- set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES OFF)
+ set_property(TARGET ggml PROPERTY CUDA_ARCHITECTURES "native")
set_property(TARGET ggml PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
- set_property(TARGET ggml_static PROPERTY CUDA_ARCHITECTURES OFF)
+ set_property(TARGET ggml_static PROPERTY CUDA_ARCHITECTURES "native")
set_property(TARGET ggml_static PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
- set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES OFF)
+ set_property(TARGET llama PROPERTY CUDA_ARCHITECTURES "native")
endif()
diff --git a/Makefile b/Makefile
index afd06e0..5dd676f 100644
--- a/Makefile
+++ b/Makefile
@@ -169,6 +169,9 @@ ifdef LLAMA_CUDA_DMMV_Y
else
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
endif # LLAMA_CUDA_DMMV_Y
+ifdef LLAMA_CUDA_DMMV_F16
+ NVCCFLAGS += -DGGML_CUDA_DMMV_F16
+endif # LLAMA_CUDA_DMMV_F16
ifdef LLAMA_CUDA_KQUANTS_ITER
NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
else
diff --git a/README.md b/README.md
index e5b3f59..2d05de3 100644
--- a/README.md
+++ b/README.md
@@ -337,7 +337,14 @@ Building the program with BLAS support may lead to some performance improvements
cmake --build . --config Release
```
- The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used.
+ The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used. The following compilation options are also available to tweak performance:
+
+ | Option | Legal values | Default | Description |
+ |-------------------------|------------------------|---------|-------------|
+ | LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
+ | LLAMA_CUDA_DMMV_Y | Positive integer | 1 | Block size in y direction for the CUDA dequantization + mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
+ | LLAMA_CUDA_DMMV_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels. Can improve performance on relatively recent GPUs. |
+ | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value 2 1 can improve performance for slow GPUs. |
- #### CLBlast
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 16488b9..9ebc57a 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -50,7 +50,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
} while (0)
#endif // CUDART_VERSION >= 11
-typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
+#ifdef GGML_CUDA_DMMV_F16
+typedef half dfloat; // dequantize float
+typedef half2 dfloat2;
+#else
+typedef float dfloat; // dequantize float
+typedef float2 dfloat2;
+#endif //GGML_CUDA_DMMV_F16
+
+typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
@@ -234,82 +242,106 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
}
}
-static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q4_0 * x = (const block_q4_0 *) vx;
- const float d = x[ib].d;
+ const dfloat d = x[ib].d;
- const uint8_t vui = x[ib].qs[iqs];
+ const int vui = x[ib].qs[iqs];
- const int8_t vi0 = vui & 0xF;
- const int8_t vi1 = vui >> 4;
+ v.x = vui & 0xF;
+ v.y = vui >> 4;
- v0 = (vi0 - 8)*d;
- v1 = (vi1 - 8)*d;
+#ifdef GGML_CUDA_DMMV_F16
+ v = __hsub2(v, {8.0f, 8.0f});
+ v = __hmul2(v, {d, d});
+#else
+ v.x = (v.x - 8.0f) * d;
+ v.y = (v.y - 8.0f) * d;
+#endif // GGML_CUDA_DMMV_F16
}
-static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q4_1 * x = (const block_q4_1 *) vx;
- const float d = x[ib].d;
- const float m = x[ib].m;
+ const dfloat d = x[ib].d;
+ const dfloat m = x[ib].m;
- const uint8_t vui = x[ib].qs[iqs];
+ const int vui = x[ib].qs[iqs];
- const int8_t vi0 = vui & 0xF;
- const int8_t vi1 = vui >> 4;
+ v.x = vui & 0xF;
+ v.y = vui >> 4;
- v0 = vi0*d + m;
- v1 = vi1*d + m;
+#ifdef GGML_CUDA_DMMV_F16
+ v = __hmul2(v, {d, d});
+ v = __hadd2(v, {m, m});
+#else
+ v.x = (v.x * d) + m;
+ v.y = (v.y * d) + m;
+#endif // GGML_CUDA_DMMV_F16
}
-static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q5_0 * x = (const block_q5_0 *) vx;
- const float d = x[ib].d;
+ const dfloat d = x[ib].d;
uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
- const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
- const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
+ const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
+ const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
- const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
- const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;
+ v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+ v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
- v0 = x0*d;
- v1 = x1*d;
+#ifdef GGML_CUDA_DMMV_F16
+ v = __hsub2(v, {16.0f, 16.0f});
+ v = __hmul2(v, {d, d});
+#else
+ v.x = (v.x - 16.0f) * d;
+ v.y = (v.y - 16.0f) * d;
+#endif // GGML_CUDA_DMMV_F16
}
-static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q5_1 * x = (const block_q5_1 *) vx;
- const float d = x[ib].d;
- const float m = x[ib].m;
+ const dfloat d = x[ib].d;
+ const dfloat m = x[ib].m;
uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
- const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
- const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
+ const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
+ const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
- const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
- const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
+ v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
+ v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
- v0 = x0*d + m;
- v1 = x1*d + m;
+#ifdef GGML_CUDA_DMMV_F16
+ v = __hmul2(v, {d, d});
+ v = __hadd2(v, {m, m});
+#else
+ v.x = (v.x * d) + m;
+ v.y = (v.y * d) + m;
+#endif // GGML_CUDA_DMMV_F16
}
-static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q8_0 * x = (const block_q8_0 *) vx;
- const float d = x[ib].d;
+ const dfloat d = x[ib].d;
- const int8_t vi0 = x[ib].qs[iqs + 0];
- const int8_t vi1 = x[ib].qs[iqs + 1];
+ v.x = x[ib].qs[iqs + 0];
+ v.y = x[ib].qs[iqs + 1];
- v0 = vi0*d;
- v1 = vi1*d;
+#ifdef GGML_CUDA_DMMV_F16
+ v = __hmul2(v, {d, d});
+#else
+ v.x *= d;
+ v.y *= d;
+#endif // GGML_CUDA_DMMV_F16
}
//================================== k-quants
@@ -843,11 +875,12 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
}
}
-static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
+static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
const half * x = (const half *) vx;
- v0 = __half2float(x[ib + iqs + 0]);
- v1 = __half2float(x[ib + iqs + 1]);
+ // automatic half -> float type cast if dfloat == float
+ v.x = x[ib + iqs + 0];
+ v.y = x[ib + iqs + 1];
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -864,13 +897,15 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
- float & v0 = y[iybs + iqs + 0];
- float & v1 = y[iybs + iqs + y_offset];
- dequantize_kernel(vx, ib, iqs, v0, v1);
+ dfloat2 v;
+ dequantize_kernel(vx, ib, iqs, v);
+
+ y[iybs + iqs + 0] = v.x;
+ y[iybs + iqs + y_offset] = v.y;
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
-static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
+static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
// qk = quantized weights per x block
// qr = number of quantized weights per data value in x block
const int row = blockIdx.y*blockDim.y + threadIdx.y;
@@ -885,7 +920,12 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
const int y_offset = qr == 1 ? 1 : qk/2;
- float tmp = 0.0f; // partial sum for thread in warp
+// partial sum for each thread
+#ifdef GGML_CUDA_DMMV_F16
+ half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
+#else
+ float tmp = 0.0f;
+#endif // GGML_CUDA_DMMV_F16
for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid;
@@ -899,14 +939,21 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
// process 2 vals per j iter
// dequantize
- float v0, v1;
- dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
+ dfloat2 v;
+ dequantize_kernel(vx, ib, iqs + j/qr, v);
// matrix multiplication
- tmp += v0 * y[iybs + iqs + j/qr + 0];
- tmp += v1 * y[iybs + iqs + j/qr + y_offset];
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
+#ifdef GGML_CUDA_DMMV_F16
+ tmp += __hmul2(v, {
+ y[iybs + iqs + j/qr + 0],
+ y[iybs + iqs + j/qr + y_offset]
+ });
+#else
+ tmp += v.x * y[iybs + iqs + j/qr + 0];
+ tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+#endif // GGML_CUDA_DMMV_F16
}
}
@@ -918,7 +965,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
}
if (tid == 0) {
+#ifdef GGML_CUDA_DMMV_F16
+ dst[row] = tmp.x + tmp.y;
+#else
dst[row] = tmp;
+#endif // GGML_CUDA_DMMV_F16
}
}
@@ -1213,7 +1264,7 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
}
-static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
const dim3 block_nums(1, block_num_y, 1);
@@ -1222,7 +1273,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, f
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
-static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
const dim3 block_nums(1, block_num_y, 1);
@@ -1231,7 +1282,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, f
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
-static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
const dim3 block_nums(1, block_num_y, 1);
@@ -1240,7 +1291,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, f
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
-static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
const dim3 block_nums(1, block_num_y, 1);
@@ -1249,7 +1300,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, f
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
-static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
const dim3 block_nums(1, block_num_y, 1);
@@ -1299,7 +1350,7 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
-static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
const dim3 block_nums(1, block_num_y, 1);
@@ -1714,21 +1765,40 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
const int64_t ne00 = src0->ne[0];
const int64_t nrows = i01_high - i01_low;
+// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
+#ifdef GGML_CUDA_DMMV_F16
+ size_t ash;
+ dfloat * src1_dfloat = nullptr; // dfloat == half
+
+ bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
+ src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
+ src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
+
+ if (src1_convert_f16) {
+ src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
+ ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
+ ne00, 1, sizeof(float), 0, 0,
+ ne00, 1, sizeof(half), 0, 0, cudaStream_main);
+ }
+#else
+ dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
+#endif // GGML_CUDA_DMMV_F16
+
switch (src0->type) {
case GGML_TYPE_Q4_0:
- dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
break;
case GGML_TYPE_Q4_1:
- dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
break;
case GGML_TYPE_Q5_0:
- dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
break;
case GGML_TYPE_Q5_1:
- dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
break;
case GGML_TYPE_Q8_0:
- dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
break;
case GGML_TYPE_Q2_K:
dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
@@ -1746,7 +1816,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
break;
case GGML_TYPE_F16:
- convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
+ convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
break;
default:
GGML_ASSERT(false);
@@ -1754,6 +1824,12 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
}
CUDA_CHECK(cudaGetLastError());
+#ifdef GGML_CUDA_DMMV_F16
+ if (src1_convert_f16) {
+ ggml_cuda_pool_free(src1_dfloat, ash);
+ }
+#endif // GGML_CUDA_DMMV_F16
+
(void) src1;
(void) dst;
(void) src0_ddf_i;
diff --git a/llama.cpp b/llama.cpp
index 2105e32..5401db0 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1620,7 +1620,7 @@ static bool llama_eval_internal(
model.layers[il].w1,
cur);
offload_func(cur);
- ggml_set_name(cur, "result_w2");
+ ggml_set_name(cur, "result_w1");
// SILU activation
cur = ggml_silu(ctx0, cur);