aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohannes Gäßler <johannesg@5d6.de>2023-08-09 09:42:34 +0200
committerGitHub <noreply@github.com>2023-08-09 09:42:34 +0200
commit25d43e0eb578b6e73046d9d6644a3a14d460600d (patch)
treefddb8e9a044ce7eda09024e345a871cdada4cac8
parentf5bfea0580e417f99850d5456ca541d871a3e48c (diff)
CUDA: tuned mul_mat_q kernels (#2546)
-rw-r--r--Makefile5
-rw-r--r--README.md1
-rw-r--r--ggml-cuda.cu1068
3 files changed, 682 insertions, 392 deletions
diff --git a/Makefile b/Makefile
index 32598ed..f01bf0c 100644
--- a/Makefile
+++ b/Makefile
@@ -253,11 +253,6 @@ ifdef LLAMA_CUDA_KQUANTS_ITER
else
NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
endif
-ifdef LLAMA_CUDA_MMQ_Y
- NVCCFLAGS += -DGGML_CUDA_MMQ_Y=$(LLAMA_CUDA_MMQ_Y)
-else
- NVCCFLAGS += -DGGML_CUDA_MMQ_Y=64
-endif # LLAMA_CUDA_MMQ_Y
#ifdef LLAMA_CUDA_CUBLAS
# NVCCFLAGS += -DGGML_CUDA_CUBLAS
#endif # LLAMA_CUDA_CUBLAS
diff --git a/README.md b/README.md
index 2ece294..6900b11 100644
--- a/README.md
+++ b/README.md
@@ -406,7 +406,6 @@ Building the program with BLAS support may lead to some performance improvements
--->
| Option | Legal values | Default | Description |
|-------------------------|------------------------|---------|-------------|
- | LLAMA_CUDA_MMQ_Y | Positive integer >= 32 | 64 | Tile size in y direction when using the custom CUDA kernels for prompt processing. Higher values can be faster depending on the amount of shared memory available. Power of 2 heavily recommended. |
| LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 9d42efb..6390b11 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -14,6 +14,7 @@
#include "ggml.h"
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define CC_TURING 700
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
@@ -262,10 +263,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
-#ifndef GGML_CUDA_MMQ_Y
-#define GGML_CUDA_MMQ_Y 64
-#endif // GGML_CUDA_MMQ_Y
-
// dmmv = dequantize_mul_mat_vec
#ifndef GGML_CUDA_DMMV_X
#define GGML_CUDA_DMMV_X 32
@@ -285,6 +282,20 @@ struct ggml_tensor_extra_gpu {
cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
};
+static int g_device_count = -1;
+static int g_main_device = 0;
+static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
+static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
+static bool g_mul_mat_q = false;
+
+static void * g_scratch_buffer = nullptr;
+static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
+static size_t g_scratch_offset = 0;
+
+static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+
+static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
+
static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -1549,8 +1560,8 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
#else
const float2 dm8f = __half22float2(dm8);
const float2 ds8f = __half22float2(ds8);
- const float d8d8 = dm8.x * ds8.x;
- const float m8s8 = dm8.y * ds8.y;
+ const float d8d8 = dm8f.x * ds8f.x;
+ const float m8s8 = dm8f.y * ds8f.y;
#endif // GGML_CUDA_F16
// scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
@@ -1884,21 +1895,21 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMVQ>(v, u, bq4_0->d, bq8_1->ds);
}
-static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
- __shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
- __shared__ float tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_0) + GGML_CUDA_MMQ_Y/QI4_0];
+ __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
+ __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0];
*x_ql = tile_x_qs;
*x_dm = (half2 *) tile_x_d;
}
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
__builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < 8);
+ __builtin_assume(i_offset < nwarps);
__builtin_assume(k >= 0);
__builtin_assume(k < WARP_SIZE);
@@ -1910,7 +1921,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
float * x_dmf = (float *) x_dm;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
@@ -1920,39 +1931,30 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
- x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
+ // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d;
}
-// const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
-// const int kbxd = k % blocks_per_tile_x_row;
+ const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
+ const int kbxd = k % blocks_per_tile_x_row;
-// #pragma unroll
-// for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_0) {
-// FIXME out-of-bounds
-// const int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) {
+ int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
-// if (i >= GGML_CUDA_MMQ_Y) {
-// return;
-// }
+ if (need_check) {
+ i = min(i, i_max);
+ }
-// const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+ const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
-// x_dm[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd].x = bxi->d;
-// }
+ x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = bxi->d;
+ }
}
static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
- __builtin_assume(i >= 0);
- __builtin_assume(i < GGML_CUDA_MMQ_Y);
- __builtin_assume(j >= 0);
- __builtin_assume(j < WARP_SIZE);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
- __builtin_assume(k % VDR_Q4_0_Q8_1_MMQ == 0);
-
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
const float * x_dmf = (float *) x_dm;
@@ -1960,13 +1962,13 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
#pragma unroll
for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) {
- u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
- u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI4_0];
+ u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
+ u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE];
}
return vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
(&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0],
- y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+ y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
@@ -1987,21 +1989,21 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMVQ>(v, u, bq4_1->dm, bq8_1->ds);
}
-static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
- __shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE) + + GGML_CUDA_MMQ_Y];
- __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_1) + GGML_CUDA_MMQ_Y/QI4_1];
+ __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y];
+ __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1];
*x_ql = tile_x_qs;
*x_dm = tile_x_dm;
}
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
__builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < 8);
+ __builtin_assume(i_offset < nwarps);
__builtin_assume(k >= 0);
__builtin_assume(k < WARP_SIZE);
@@ -2011,7 +2013,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
const block_q4_1 * bx0 = (block_q4_1 *) vx;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
@@ -2027,7 +2029,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_1) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) {
int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
if (need_check) {
@@ -2044,27 +2046,19 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
- __builtin_assume(i >= 0);
- __builtin_assume(i < GGML_CUDA_MMQ_Y);
- __builtin_assume(j >= 0);
- __builtin_assume(j < WARP_SIZE);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
- __builtin_assume(k % VDR_Q4_1_Q8_1_MMQ == 0);
-
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
int u[2*VDR_Q4_1_Q8_1_MMQ];
#pragma unroll
for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) {
- u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
- u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI4_1];
+ u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
+ u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE];
}
return vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ>
(&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1],
- y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+ y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
@@ -2087,21 +2081,21 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
return vec_dot_q5_0_q8_1_impl<VDR_Q5_0_Q8_1_MMVQ>(vl, vh, u, bq5_0->d, bq8_1->ds);
}
-static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
- __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE) + GGML_CUDA_MMQ_Y];
- __shared__ float tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0) + GGML_CUDA_MMQ_Y/QI5_0];
+ __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
+ __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0];
*x_ql = tile_x_ql;
*x_dm = (half2 *) tile_x_d;
}
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
__builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < 8);
+ __builtin_assume(i_offset < nwarps);
__builtin_assume(k >= 0);
__builtin_assume(k < WARP_SIZE);
@@ -2111,7 +2105,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
const block_q5_0 * bx0 = (block_q5_0 *) vx;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
@@ -2147,7 +2141,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
float * x_dmf = (float *) x_dm;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_0) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) {
int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
if (need_check) {
@@ -2164,14 +2158,6 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
- __builtin_assume(i >= 0);
- __builtin_assume(i < GGML_CUDA_MMQ_Y);
- __builtin_assume(j >= 0);
- __builtin_assume(j < WARP_SIZE);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
- __builtin_assume(k % VDR_Q5_0_Q8_1_MMQ == 0);
-
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
const float * x_dmf = (const float *) x_dm;
@@ -2181,12 +2167,12 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
#pragma unroll
for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) {
- u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
- u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI5_0];
+ u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
+ u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE];
}
return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
- (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+ (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
@@ -2209,21 +2195,21 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
return vec_dot_q5_1_q8_1_impl<VDR_Q5_1_Q8_1_MMVQ>(vl, vh, u, bq5_1->dm, bq8_1->ds);
}
-static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
- __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE) + GGML_CUDA_MMQ_Y];
- __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1) + GGML_CUDA_MMQ_Y/QI5_1];
+ __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
+ __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
}
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
__builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < 8);
+ __builtin_assume(i_offset < nwarps);
__builtin_assume(k >= 0);
__builtin_assume(k < WARP_SIZE);
@@ -2233,7 +2219,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
const block_q5_1 * bx0 = (block_q5_1 *) vx;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
@@ -2266,7 +2252,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_1) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) {
int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
if (need_check) {
@@ -2283,14 +2269,6 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
- __builtin_assume(i >= 0);
- __builtin_assume(i < GGML_CUDA_MMQ_Y);
- __builtin_assume(j >= 0);
- __builtin_assume(j < WARP_SIZE);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
- __builtin_assume(k % VDR_Q5_1_Q8_1_MMQ == 0);
-
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
@@ -2298,12 +2276,12 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
#pragma unroll
for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) {
- u[2*l+0] = y_qs[j * (2*WARP_SIZE) + kyqs + l];
- u[2*l+1] = y_qs[j * (2*WARP_SIZE) + kyqs + l + QI5_1];
+ u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE];
+ u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE];
}
return vec_dot_q8_1_q8_1_impl<QR5_1*VDR_Q5_1_Q8_1_MMQ>
- (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+ (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]);
}
static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
@@ -2323,21 +2301,21 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds.x);
}
-static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
- __shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
- __shared__ float tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI8_0) + GGML_CUDA_MMQ_Y/QI8_0];
+ __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y];
+ __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0];
*x_ql = tile_x_qs;
*x_dm = (half2 *) tile_x_d;
}
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
__builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < 8);
+ __builtin_assume(i_offset < nwarps);
__builtin_assume(k >= 0);
__builtin_assume(k < WARP_SIZE);
@@ -2348,7 +2326,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q8_
const block_q8_0 * bx0 = (block_q8_0 *) vx;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
@@ -2358,41 +2336,29 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q8_
const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
- x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbx] = bxi->d;
}
-// const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
-// const int kbxd = k % blocks_per_tile_x_row;
+ const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
+ const int kbxd = k % blocks_per_tile_x_row;
-// #pragma unroll
-// for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI8_0) {
-// FIXME out-of-bounds
-// const int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0) {
+ int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
-// #if GGML_CUDA_MMQ_Y < 64
-// if (i >= GGML_CUDA_MMQ_Y) {
-// return;
-// }
-// #endif // GGML_CUDA_MMQ_Y < 64
+ if (need_check) {
+ i = min(i, i_max);
+ }
-// const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
+ const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
-// x_dm[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd].x = bxi->d;
-// }
+ x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = bxi->d;
+ }
}
static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
- __builtin_assume(i >= 0);
- __builtin_assume(i < GGML_CUDA_MMQ_Y);
- __builtin_assume(j >= 0);
- __builtin_assume(j < WARP_SIZE);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
- __builtin_assume(k % VDR_Q8_0_Q8_1_MMQ == 0);
-
const float * x_dmf = (const float *) x_dm;
const float * y_df = (const float *) y_ds;
@@ -2424,23 +2390,23 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
}
-static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
- __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
- __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI2_K) + GGML_CUDA_MMQ_Y/QI2_K];
- __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/4) + GGML_CUDA_MMQ_Y/4];
+ __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y];
+ __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K];
+ __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
*x_sc = tile_x_sc;
}
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
__builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < 8);
+ __builtin_assume(i_offset < nwarps);
__builtin_assume(k >= 0);
__builtin_assume(k < WARP_SIZE);
@@ -2450,7 +2416,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q2_
const block_q2_K * bx0 = (block_q2_K *) vx;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
@@ -2466,8 +2432,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q2_
const int kbxd = k % blocks_per_tile_x_row;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI2_K) {
- int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI2_K) {
+ int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % mmq_y;
if (need_check) {
i = min(i, i_max);
@@ -2479,7 +2445,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q2_
}
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 4) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
if (need_check) {
@@ -2496,14 +2462,6 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
- __builtin_assume(i >= 0);
- __builtin_assume(i < GGML_CUDA_MMQ_Y);
- __builtin_assume(j >= 0);
- __builtin_assume(j < WARP_SIZE);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
- __builtin_assume(k % VDR_Q2_K_Q8_1_MMQ == 0);
-
const int kbx = k / QI2_K;
const int ky = (k % QI2_K) * QR2_K;
const float * y_df = (const float *) y_ds;
@@ -2520,7 +2478,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
- const int index_y = j * (QR2_K*WARP_SIZE) + QR2_K*k;
+ const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE;
return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
}
@@ -2551,12 +2509,12 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
}
-static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
- __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
- __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI3_K) + GGML_CUDA_MMQ_Y/QI3_K];
- __shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/2) + GGML_CUDA_MMQ_Y/2];
- __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/4) + GGML_CUDA_MMQ_Y/4];
+ __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y];
+ __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K];
+ __shared__ int tile_x_qh[mmq_y * (WARP_SIZE/2) + mmq_y/2];
+ __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
@@ -2564,12 +2522,12 @@ static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 **
*x_sc = tile_x_sc;
}
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
__builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < 8);
+ __builtin_assume(i_offset < nwarps);
__builtin_assume(k >= 0);
__builtin_assume(k < WARP_SIZE);
@@ -2579,7 +2537,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
const block_q3_K * bx0 = (block_q3_K *) vx;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
@@ -2596,8 +2554,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
float * x_dmf = (float *) x_dm;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI3_K) {
- int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI3_K) {
+ int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % mmq_y;
if (need_check) {
i = min(i, i_max);
@@ -2609,7 +2567,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
}
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 2) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) {
int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
if (need_check) {
@@ -2623,7 +2581,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
}
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 4) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
if (need_check) {
@@ -2652,14 +2610,6 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
- __builtin_assume(i >= 0);
- __builtin_assume(i < GGML_CUDA_MMQ_Y);
- __builtin_assume(j >= 0);
- __builtin_assume(j < WARP_SIZE);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
- __builtin_assume(k % VDR_Q3_K_Q8_1_MMQ == 0);
-
const int kbx = k / QI3_K;
const int ky = (k % QI3_K) * QR3_K;
const float * x_dmf = (const float *) x_dm;
@@ -2681,7 +2631,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
v[l] = __vsubss4(vll, vlh);
}
- const int index_y = j * (QR3_K*WARP_SIZE) + k*QR3_K;
+ const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE;
return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
}
@@ -2778,23 +2728,23 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
#endif
}
-static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
- __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
- __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_K) + GGML_CUDA_MMQ_Y/QI4_K];
- __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8) + GGML_CUDA_MMQ_Y/8];
+ __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y];
+ __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K];
+ __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
*x_sc = tile_x_sc;
}
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
__builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < 8);
+ __builtin_assume(i_offset < nwarps);
__builtin_assume(k >= 0);
__builtin_assume(k < WARP_SIZE);
@@ -2804,7 +2754,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
const block_q4_K * bx0 = (block_q4_K *) vx;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
@@ -2820,8 +2770,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_K) {
- int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) {
+ int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % mmq_y;
if (need_check) {
i = min(i, i_max);
@@ -2833,8 +2783,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
}
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
- int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+ int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
if (need_check) {
i = min(i, i_max);
@@ -2858,14 +2808,6 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
- __builtin_assume(i >= 0);
- __builtin_assume(i < GGML_CUDA_MMQ_Y);
- __builtin_assume(j >= 0);
- __builtin_assume(j < WARP_SIZE);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
- __builtin_assume(k % VDR_Q4_K_Q8_1_MMQ == 0);
-
int v[QR4_K*VDR_Q4_K_Q8_1_MMQ];
#pragma unroll
@@ -2876,7 +2818,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
- const int index_y = j * (QR4_K*WARP_SIZE) + QR4_K*k;
+ const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE;
return vec_dot_q4_K_q8_1_impl_mmq(v, &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
}
@@ -2969,23 +2911,23 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
#endif
}
-static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
- __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE) + GGML_CUDA_MMQ_Y];
- __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_K) + GGML_CUDA_MMQ_Y/QI5_K];
- __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8) + GGML_CUDA_MMQ_Y/8];
+ __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
+ __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K];
+ __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
*x_sc = tile_x_sc;
}
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
__builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < 8);
+ __builtin_assume(i_offset < nwarps);
__builtin_assume(k >= 0);
__builtin_assume(k < WARP_SIZE);
@@ -2995,7 +2937,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
const block_q5_K * bx0 = (block_q5_K *) vx;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
@@ -3024,8 +2966,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_K) {
- int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) {
+ int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % mmq_y;
if (need_check) {
i = min(i, i_max);
@@ -3037,8 +2979,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
}
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
- int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+ int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
if (need_check) {
i = min(i, i_max);
@@ -3062,18 +3004,10 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
- __builtin_assume(i >= 0);
- __builtin_assume(i < GGML_CUDA_MMQ_Y);
- __builtin_assume(j >= 0);
- __builtin_assume(j < WARP_SIZE);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
- __builtin_assume(k % VDR_Q5_K_Q8_1_MMQ == 0);
-
const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
- const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
- const int index_y = j * (QR5_K*WARP_SIZE) + QR5_K*k;
+ const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
+ const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE;
return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
}
@@ -3103,23 +3037,23 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
}
-static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
- __shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE) + GGML_CUDA_MMQ_Y];
- __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI6_K) + GGML_CUDA_MMQ_Y/QI6_K];
- __shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8) + GGML_CUDA_MMQ_Y/8];
+ __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y];
+ __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K];
+ __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8];
*x_ql = tile_x_ql;
*x_dm = tile_x_dm;
*x_sc = tile_x_sc;
}
-template <bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
__builtin_assume(i_offset >= 0);
- __builtin_assume(i_offset < 8);
+ __builtin_assume(i_offset < nwarps);
__builtin_assume(k >= 0);
__builtin_assume(k < WARP_SIZE);
@@ -3129,7 +3063,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q6_
const block_q6_K * bx0 = (block_q6_K *) vx;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
int i = i0 + i_offset;
if (need_check) {
@@ -3159,8 +3093,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q6_
float * x_dmf = (float *) x_dm;
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI6_K) {
- int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) {
+ int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % mmq_y;
if (need_check) {
i = min(i, i_max);
@@ -3172,8 +3106,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q6_
}
#pragma unroll
- for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
- int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) {
+ int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y;
if (need_check) {
i = min(i, i_max);
@@ -3189,25 +3123,17 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
- __builtin_assume(i >= 0);
- __builtin_assume(i < GGML_CUDA_MMQ_Y);
- __builtin_assume(j >= 0);
- __builtin_assume(j < WARP_SIZE);
- __builtin_assume(k >= 0);
- __builtin_assume(k < WARP_SIZE);
- __builtin_assume(k % VDR_Q6_K_Q8_1_MMQ == 0);
-
const float * x_dmf = (const float *) x_dm;
const float * y_df = (const float *) y_ds;
const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
- const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
- const int index_y = j * (QR6_K*WARP_SIZE) + QR6_K*k;
+ const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
+ const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE;
return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
}
-template <int qk, int qr, int qi, bool need_sum, typename block_q_t,
+template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
static __global__ void mul_mat_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
@@ -3222,14 +3148,11 @@ static __global__ void mul_mat_q(
const int & ncols_dst = ncols_y;
- const int tid_x = threadIdx.x;
- const int tid_y = threadIdx.y;
-
- const int row_dst_0 = blockIdx.x*GGML_CUDA_MMQ_Y;
+ const int row_dst_0 = blockIdx.x*mmq_y;
const int & row_x_0 = row_dst_0;
- const int row_dst = row_dst_0 + tid_x;
+ const int row_dst = row_dst_0 + threadIdx.x;
- const int col_dst_0 = blockIdx.y*WARP_SIZE;
+ const int col_dst_0 = blockIdx.y*mmq_x;
const int & col_y_0 = col_dst_0;
int * tile_x_ql = nullptr;
@@ -3239,64 +3162,65 @@ static __global__ void mul_mat_q(
allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
- const int blocks_per_tile_y_col = qr*WARP_SIZE/QI8_1;
-
- __shared__ int tile_y_qs[(WARP_SIZE) * (qr*WARP_SIZE)];
- __shared__ half2 tile_y_ds[(WARP_SIZE) * blocks_per_tile_y_col];
+ __shared__ int tile_y_qs[mmq_x * WARP_SIZE];
+ __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
- float sum[GGML_CUDA_MMQ_Y/WARP_SIZE][4] = {0.0f};
+ float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {0.0f};
for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
- tid_y, nrows_x-row_x_0-1, tid_x, blocks_per_row_x);
+ threadIdx.y, nrows_x-row_x_0-1, threadIdx.x, blocks_per_row_x);
+#pragma unroll
for (int ir = 0; ir < qr; ++ir) {
- const int kqs = ir*WARP_SIZE + tid_x;
+ const int kqs = ir*WARP_SIZE + threadIdx.x;
const int kbxd = kqs / QI8_1;
- for (int i = 0; i < WARP_SIZE; i += 8) {
- const int col_y_eff = min(col_y_0 + tid_y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
+#pragma unroll
+ for (int i = 0; i < mmq_x; i += nwarps) {
+ const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
- tile_y_qs[(tid_y + i) * (qr*WARP_SIZE) + kqs] = get_int_from_int8_aligned(by0->qs, tid_x % QI8_1);
+ const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE;
+ tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
}
- }
- for (int ids0 = 0; ids0 < WARP_SIZE; ids0 += 8 * (WARP_SIZE/blocks_per_tile_y_col)) {
- const int ids = (ids0 + tid_y * (WARP_SIZE/blocks_per_tile_y_col) + tid_x / blocks_per_tile_y_col) % WARP_SIZE;
- const int kby = tid_x % blocks_per_tile_y_col;
- const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
-
- // if the sum is not needed it's faster to transform the scale to f32 ahead of time
- const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby].ds;
- half2 * dsi_dst = &tile_y_ds[ids * (qr*WARP_SIZE/QI8_1) + kby];
- if (need_sum) {
- *dsi_dst = *dsi_src;
- } else {
- float * dfi_dst = (float *) dsi_dst;
- *dfi_dst = (*dsi_src).x;
+#pragma unroll
+ for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
+ const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
+ const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
+ const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
+
+ // if the sum is not needed it's faster to transform the scale to f32 ahead of time
+ const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
+ half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
+ if (need_sum) {
+ *dsi_dst = *dsi_src;
+ } else {
+ float * dfi_dst = (float *) dsi_dst;
+ *dfi_dst = (*dsi_src).x;
+ }
}
- }
- __syncthreads();
+ __syncthreads();
-#if __CUDA_ARCH__ >= 700 // Unrolling the loop is slower on Pascal
+// #pragma unroll // unrolling this loop causes too much register pressure
+ for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
#pragma unroll
-#endif // __CUDA_ARCH__ >= 700
- for (int k = 0; k < WARP_SIZE; k += vdr) {
+ for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
- for (int j = 0; j < WARP_SIZE; j += 8) {
-#pragma unroll
- for (int i = 0; i < GGML_CUDA_MMQ_Y; i += WARP_SIZE) {
- sum[i/WARP_SIZE][j/8] += vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
- tid_x + i, tid_y + j, k);
+ for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+ sum[i/WARP_SIZE][j/nwarps] += vec_dot(
+ tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
+ threadIdx.x + i, threadIdx.y + j, k);
+ }
}
}
- }
- __syncthreads();
+ __syncthreads();
+ }
}
@@ -3304,15 +3228,15 @@ static __global__ void mul_mat_q(
return;
}
- for (int j = 0; j < WARP_SIZE; j += 8) {
- const int col_dst = col_dst_0 + j + tid_y;
+ for (int j = 0; j < mmq_x; j += nwarps) {
+ const int col_dst = col_dst_0 + j + threadIdx.y;
if (col_dst >= ncols_dst) {
return;
}
- for (int i = 0; i < GGML_CUDA_MMQ_Y; i += WARP_SIZE) {
- dst[col_dst*nrows_dst + row_dst + i] = sum[i/WARP_SIZE][j/8];
+ for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+ dst[col_dst*nrows_dst + row_dst + i] = sum[i/WARP_SIZE][j/nwarps];
}
}
}
@@ -4014,17 +3938,52 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
- const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
- if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
- mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<false>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ const int compute_capability = g_compute_capabilities[id];
+
+ if (compute_capability >= CC_TURING) {
+ const int mmq_x = 64;
+ const int mmq_y = 128;
+ const int nwarps = 4;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+ load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+ load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
} else {
- mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<true>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ const int mmq_x = 64;
+ const int mmq_y = 64;
+ const int nwarps = 4;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+ load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, mmq_x, mmq_y, nwarps, allocate_tiles_q4_0<mmq_y>,
+ load_tiles_q4_0<mmq_y, nwarps, need_check>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
}
}
@@ -4032,17 +3991,53 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
- const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
- if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
- mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<false>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ const int compute_capability = g_compute_capabilities[id];
+
+ if (compute_capability >= CC_TURING) {
+ const int mmq_x = 64;
+ const int mmq_y = 128;
+ const int nwarps = 4;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
+ load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
+ load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
} else {
- mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<true>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ const int mmq_x = 64;
+ const int mmq_y = 64;
+ const int nwarps = 8;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
+ load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, mmq_x, mmq_y, nwarps, allocate_tiles_q4_1<mmq_y>,
+ load_tiles_q4_1<mmq_y, nwarps, need_check>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
+
}
}
@@ -4050,17 +4045,52 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
- const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
- if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
- mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<false>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ const int compute_capability = g_compute_capabilities[id];
+
+ if (compute_capability >= CC_TURING) {
+ const int mmq_x = 128;
+ const int mmq_y = 64;
+ const int nwarps = 4;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
+ load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
+ load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
} else {
- mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<true>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ const int mmq_x = 64;
+ const int mmq_y = 64;
+ const int nwarps = 8;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
+ load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, mmq_x, mmq_y, nwarps, allocate_tiles_q5_0<mmq_y>,
+ load_tiles_q5_0<mmq_y, nwarps, need_check>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
}
}
@@ -4068,17 +4098,52 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
- const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
- if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
- mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<false>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ const int compute_capability = g_compute_capabilities[id];
+
+ if (compute_capability >= CC_TURING) {
+ const int mmq_x = 128;
+ const int mmq_y = 64;
+ const int nwarps = 8;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
+ load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
+ load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
} else {
- mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<true>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ const int mmq_x = 64;
+ const int mmq_y = 64;
+ const int nwarps = 8;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
+ load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, mmq_x, mmq_y, nwarps, allocate_tiles_q5_1<mmq_y>,
+ load_tiles_q5_1<mmq_y, nwarps, need_check>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
}
}
@@ -4086,17 +4151,52 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
- const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
- if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
- mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<false>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ const int compute_capability = g_compute_capabilities[id];
+
+ if (compute_capability >= CC_TURING) {
+ const int mmq_x = 128;
+ const int mmq_y = 64;
+ const int nwarps = 4;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
+ load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
+ load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
} else {
- mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<true>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ const int mmq_x = 64;
+ const int mmq_y = 64;
+ const int nwarps = 8;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
+ load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, mmq_x, mmq_y, nwarps, allocate_tiles_q8_0<mmq_y>,
+ load_tiles_q8_0<mmq_y, nwarps, need_check>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
}
}
@@ -4104,17 +4204,52 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
- const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
- if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
- mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<false>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ const int compute_capability = g_compute_capabilities[id];
+
+ if (compute_capability >= CC_TURING) {
+ const int mmq_x = 64;
+ const int mmq_y = 128;
+ const int nwarps = 4;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
+ load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
+ load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
} else {
- mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<true>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ const int mmq_x = 64;
+ const int mmq_y = 64;
+ const int nwarps = 8;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
+ load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, mmq_x, mmq_y, nwarps, allocate_tiles_q2_K<mmq_y>,
+ load_tiles_q2_K<mmq_y, nwarps, need_check>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
}
}
@@ -4122,17 +4257,52 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
- const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
- if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
- mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<false>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ const int compute_capability = g_compute_capabilities[id];
+
+ if (compute_capability >= CC_TURING) {
+ const int mmq_x = 128;
+ const int mmq_y = 128;
+ const int nwarps = 4;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
+ load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
+ load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
} else {
- mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<true>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ const int mmq_x = 64;
+ const int mmq_y = 64;
+ const int nwarps = 8;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
+ load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, mmq_x, mmq_y, nwarps, allocate_tiles_q3_K<mmq_y>,
+ load_tiles_q3_K<mmq_y, nwarps, need_check>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
}
}
@@ -4140,17 +4310,52 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
- const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
- if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
- mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<false>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ const int compute_capability = g_compute_capabilities[id];
+
+ if (compute_capability >= CC_TURING) {
+ const int mmq_x = 64;
+ const int mmq_y = 128;
+ const int nwarps = 4;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
+ load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
+ load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
} else {
- mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<true>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ const int mmq_x = 32;
+ const int mmq_y = 64;
+ const int nwarps = 8;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
+ load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, mmq_x, mmq_y, nwarps, allocate_tiles_q4_K<mmq_y>,
+ load_tiles_q4_K<mmq_y, nwarps, need_check>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
}
}
@@ -4158,17 +4363,52 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
- const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
- if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
- mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<false>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ const int compute_capability = g_compute_capabilities[id];
+
+ if (compute_capability >= CC_TURING) {
+ const int mmq_x = 64;
+ const int mmq_y = 128;
+ const int nwarps = 4;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
+ load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
+ load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
} else {
- mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<true>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ const int mmq_x = 64;
+ const int mmq_y = 64;
+ const int nwarps = 8;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
+ load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, mmq_x, mmq_y, nwarps, allocate_tiles_q5_K<mmq_y>,
+ load_tiles_q5_K<mmq_y, nwarps, need_check>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
}
}
@@ -4176,17 +4416,52 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
- const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
- const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
- const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
-
- if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
- mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<false>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ int id;
+ CUDA_CHECK(cudaGetDevice(&id));
+ const int compute_capability = g_compute_capabilities[id];
+
+ if (compute_capability >= CC_TURING) {
+ const int mmq_x = 64;
+ const int mmq_y = 64;
+ const int nwarps = 4;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
+ load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
+ load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
} else {
- mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<true>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
- <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ const int mmq_x = 32;
+ const int mmq_y = 64;
+ const int nwarps = 8;
+
+ const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
+ const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
+ const dim3 block_nums(block_num_x, block_num_y, 1);
+ const dim3 block_dims(WARP_SIZE, nwarps, 1);
+
+ if (nrows_x % mmq_y == 0) {
+ const bool need_check = false;
+ mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
+ load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ } else {
+ const bool need_check = true;
+ mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, mmq_x, mmq_y, nwarps, allocate_tiles_q6_K<mmq_y>,
+ load_tiles_q6_K<mmq_y, nwarps, need_check>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
+ <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
+ }
}
}
@@ -4361,20 +4636,6 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
}
-static void * g_scratch_buffer = nullptr;
-static size_t g_scratch_size = 1024*1024*1024; // 1 GB by default
-static size_t g_scratch_offset = 0;
-
-static int g_device_count = -1;
-static int g_main_device = 0;
-static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
-static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
-static bool g_mul_mat_q = false;
-
-static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
-
-static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES] = { nullptr };
-
void ggml_init_cublas() {
static bool initialized = false;
@@ -4730,6 +4991,37 @@ inline void ggml_cuda_op_mul_mat_q(
(void) i1;
}
+static int64_t get_row_rounding(ggml_type type) {
+ int max_compute_capability = INT_MIN;
+ for (int id = 0; id < g_device_count; ++id) {
+ if (max_compute_capability < g_compute_capabilities[id]
+ && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
+ max_compute_capability = g_compute_capabilities[id];
+ }
+ }
+
+ switch(type) {
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ return max_compute_capability >= CC_TURING ? 128 : 64;
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ return 64;
+ case GGML_TYPE_F16:
+ return 1;
+ case GGML_TYPE_Q2_K:
+ case GGML_TYPE_Q3_K:
+ case GGML_TYPE_Q4_K:
+ case GGML_TYPE_Q5_K:
+ return max_compute_capability >= CC_TURING ? 128 : 64;
+ case GGML_TYPE_Q6_K:
+ return 64;
+ default:
+ GGML_ASSERT(false);
+ }
+}
+
inline void ggml_cuda_op_mul_mat_vec(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -5130,14 +5422,16 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
int64_t row_low, row_high;
if (split) {
+ const int64_t rounding = get_row_rounding(src0->type);
+
row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
- row_low -= row_low % GGML_CUDA_MMQ_Y;
+ row_low -= row_low % rounding;
if (id == g_device_count - 1) {
row_high = nrows0;
} else {
row_high = nrows0*g_tensor_split[id + 1];
- row_high -= row_high % GGML_CUDA_MMQ_Y;
+ row_high -= row_high % rounding;
}
} else {
row_low = 0;
@@ -5616,14 +5910,16 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
row_low = 0;
row_high = nrows;
} else if (backend == GGML_BACKEND_GPU_SPLIT) {
+ const int64_t rounding = get_row_rounding(tensor->type);
+
row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
- row_low -= row_low % GGML_CUDA_MMQ_Y;
+ row_low -= row_low % rounding;
if (id == g_device_count - 1) {
row_high = nrows;
} else {
row_high = nrows*g_tensor_split[id + 1];
- row_high -= row_high % GGML_CUDA_MMQ_Y;
+ row_high -= row_high % rounding;
}
} else {
GGML_ASSERT(false);