path: root/ggml-cuda.cu
Age | Commit message | Author
2023-06-17 | ggml : fix warnings under MSVC (#1908) | Howard Su
2023-06-16 | CUDA : faster k-quant dot kernels (#1862) | Kawrakow
* cuda : faster k-quant dot kernels
* Improve Q2_K dot kernel on older GPUs
  We now have a K_QUANTS_PER_ITERATION macro, which should be set to 1 on older GPUs and to 2 on newer ones. With this, we preserve the performance of the original PR on an RTX 4080 and are faster than master on a GTX 1660.
* Improve Q6_K dot kernel on older GPUs
  Using the same K_QUANTS_PER_ITERATION macro as in the last commit, we preserve performance on the RTX 4080 and speed up Q6_K on the GTX 1660.
* Add LLAMA_CUDA_KQUANTS_ITER to CMakeLists.txt and Makefile
  Allowed values are 1 or 2. 2 gives the best performance on modern GPUs and is the default. On older GPUs, 1 may work better.
* PR comments
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
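A build-time knob restricted to 1 or 2 can be enforced at compile time. The sketch below assumes the build passes the chosen value as `-DK_QUANTS_PER_ITERATION=...` (for example derived from `LLAMA_CUDA_KQUANTS_ITER`) and falls back to 2 otherwise; the helper `kquant_iterations` is hypothetical and only shows how the value changes a loop's trip count.

```c
#include <assert.h>

/* Assumed default: 2, the value the commit recommends for modern GPUs.
   Older GPUs would build with -DK_QUANTS_PER_ITERATION=1. */
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 2
#endif

/* Reject any other value when compiling. */
_Static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2,
               "K_QUANTS_PER_ITERATION must be 1 or 2");

/* Hypothetical helper: number of loop iterations needed to cover n items
   when each iteration handles K_QUANTS_PER_ITERATION of them. */
static int kquant_iterations(int n) {
    assert(n >= 0);
    return (n + K_QUANTS_PER_ITERATION - 1) / K_QUANTS_PER_ITERATION;
}
```

Doing more work per iteration trades register pressure for fewer loop iterations, which plausibly explains why the best value differs between older and newer GPUs.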
2023-06-15 | Fixed CUDA runtime version check (#1879) | Johannes Gäßler
2023-06-15 | Fix the validation of main device (#1872) | Howard Su
2023-06-14 | CUDA full GPU acceleration, KV cache in VRAM (#1827) | Johannes Gäßler
* Fixed CUDA RoPE
* ggml_cuda_mul_mat_vec_p021
* ggml_cuda_scale
* ggml_cuda_diag_mask_inf
* ggml_is_permuted
* ggml_cuda_cpy
* flatten rows for ggml_cuda_op
* Added a --low-vram option
* Fixed Windows performance
* Fixed LLAMA_CUDA_DMMV_Y > 1 for WizardLM
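`ggml_cuda_diag_mask_inf` moves the causal attention mask onto the GPU. A minimal CPU reference of the masking rule, assuming a row-major `nrows x ncols` score matrix where query row `row` may not see column `col` when `col > n_past + row`, could look like this; it is a sketch of the rule, not the CUDA kernel.

```c
#include <assert.h>
#include <math.h>
#include <stddef.h>

/* Causal mask: copy x to dst, but set every element above the
   n_past-shifted diagonal to -INFINITY so softmax gives it zero weight. */
static void diag_mask_inf(const float *x, float *dst, int nrows, int ncols, int n_past) {
    assert(nrows >= 0 && ncols >= 0);
    for (int row = 0; row < nrows; ++row) {
        for (int col = 0; col < ncols; ++col) {
            const size_t i = (size_t) row * ncols + col;
            dst[i] = col > n_past + row ? -INFINITY : x[i];
        }
    }
}
```

On the GPU, each thread would typically handle one element and apply the same comparison independently, which is why the operation parallelizes trivially.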
2023-06-12 | Leverage mmap for offloading tensors to GPU (#1597) | Howard Su
* Rebase to latest
* Show progress
* Add assert to make sure we only allocate temp buffer for non-CPU backend tensor
  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
---------
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
2023-06-11 | Fixed WSL cuda's OOM error (#1594) | Kyle Liang
* In the function , add the CUDA error bypass.
* Remove excessive code and prints
---------
Co-authored-by: liang <liangmanlai@126.com>
2023-06-09 | Windows nvcc workaround (#1753) | Johannes Gäßler
Fix gibberish output on Windows when using CUDA
2023-06-07 | k-quants : allow to optionally disable at compile time (#1734) | Georgi Gerganov
* k-quants : put behind optional compile flag LLAMA_K_QUANTS
* build : enable k-quants by default
2023-06-06 | Multi GPU support, CUDA refactor, CUDA scratch buffer (#1703) | Johannes Gäßler
* CUDA multi GPU + scratch
  ggml_cuda_compute_forward
  Tensor parallelism
  ggml_cuda_add
  ggml_cuda_rms_norm
  ggml_cuda_silu
  CUDA scratch buffer
  --main-gpu CLI option
2023-06-05 | ggml : add SOTA 2,3,4,5,6 bit k-quantizations (#1684) | Kawrakow
* Starting to add k-quantization to ggml
  I think it is better to have quantization separate from ggml. For now just adding the k-quants there, but it would be better to also factor out the existing ggml quantizations.
* Adding Q3_K and Q8_K (de)-quantization
* Q3_K now working on CUDA and AVX2/scalar
  CUDA is not ideal: ~50% slower than Q4_0 for single token prediction, about the same in batch mode (perplexity). CPU single token is ~55 ms (on Ryzen 7950X).
* Some improvement for Q3_K on CUDA
  It is now ~22.5 ms/token on my GPU, so ~30% slower than Q4_0.
* Some more CUDA optimizations for Q3_K
  Single token is now 20.5 ms/token (~20% slower than Q4_0). Perplexity is on par with Q4_0.
* Adding Q4_K - scalar, AVX2, CUDA
  Performance is the same or perhaps very slightly better than Q4_0 on the CPU. On the GPU, single token prediction is ~10% faster than Q4_0; in batch mode (perplexity) it is about the same.
* Adding Q6_K - scalar, AVX2, CUDA
  Performance is ~40% lower than Q4_K on the CPU. This is to be expected, considering that we are memory bound on the CPU and the 6-bit model is ~44% larger than the 4-bit one. On the GPU, single token prediction is ~6% slower than Q4_0; batch mode (perplexity) is even closer (but still slower).
* Adding Q5_K - scalar, AVX2, CUDA
  Performance is ~20% lower than Q4_K on the CPU. This is to be expected, considering that we are memory bound on the CPU and the 5-bit model is ~22% larger than the 4-bit one. On the GPU, performance is about the same as Q4_0 for both single token and batch prediction.
* Per convention, all QX_K quantizations use Q5_K for output.weight
* Adding quantization mixes
* Quantization mixes: didn't quite get what I wanted in the last commit
* Q4_K dot product for ARM_NEON
* Q6_K dot product for ARM_NEON
* Q5_K dot product for ARM_NEON
* Adding Q3_K dot for ARM_NEON
  It is 22% slower than Q4_K, despite the smaller model size. On x86_64, where we are memory bound, the Q3_K model is quite a bit faster than Q4_K.
* A very slightly faster ARM_NEON Q3_K dot
* Adding Q2_K - just CUDA for now
  Token prediction is pretty good: about 15.5 ms on an RTX 4080. Perplexity is about the same as Q4_K.
* Adding scalar and AVX2 Q2_K dot
* Adding ARM_NEON Q2_K dot
  About the same performance as Q4_K.
* A slightly faster ARM_NEON Q2_K dot
  Single token prediction is now ~36 ms on M2 Max. The code is much simpler too.
* Fixed bug in Q2_K CUDA dot product kernel
  Strangely enough, for the few prompts I tried with the 7B model the responses looked perfectly reasonable. I only realized something was not quite right when I tried the larger models and started getting nonsense back.
  In any case, Q2_K single token evaluation time on an RTX 4080 in a Ryzen 7950X box, using CUDA with the model fully loaded on the GPU, is ~15.5 ms for 7B, ~25.4 ms for 13B, and ~55.8 ms for 30B. For the 65B model, at most 32 layers fit in VRAM. With that, we get ~330 ms per token, which is not that much faster than just running on the CPU (~470 ms per token).
* Don't print zeros/NaNs when no count histogram has been collected
* A 10% faster CUDA vector dot kernel for Q3_K
  Q3_K is now running at ~18.5 ms/token on CUDA, so the gap to Q4_0 is only 10%. It seems the memory access pattern matters more for performance than the amount of computation the kernel does.
* A slightly faster Q4_K AVX2 dot product
  For perplexity, where we are less memory bound, time per pass drops by ~5%. Barely measurable difference for single token prediction.
* A slightly faster ARM_NEON Q4_K dot product
* Minor
* Fix quantization error test
  We cannot possibly expect rmse < 0.002 for 2- and 3-bit quantization variants.
* Fix docker build
  I have been sloppy with vector reinterpret casts on ARM_NEON. It seems clang is very forgiving in that regard.
* Added forgotten ggml.o dependency on k_quants.h to the Makefile
* Had unintentionally committed the Makefile with -Ofast enabled
* ggml : rename k_quants -> ggml-quants-k, use lowercase in code
---------
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
2023-05-26 | cuda : performance optimizations (#1530) | Johannes Gäßler
* xor hack
* block y dim
* loop unrolling
* Fixed cmake LLAMA_CUDA_BY option
* Removed hipblas compatibility code
* Define GGML_CUDA_DMMV_BLOCK_Y if not defined
* Fewer iters, more ops per iter
* Renamed DMMV X/Y compilation options
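The "xor hack" most likely refers to a warp-level butterfly reduction: in each round, lane `i` adds the value held by lane `i ^ mask` for mask = 16, 8, 4, 2, 1, which CUDA exposes as `__shfl_xor_sync`. The C sketch below simulates the 32 lanes on the CPU to show why every lane ends up holding the full sum; it illustrates the pattern and is not the kernel from this commit.

```c
#include <assert.h>

#define WARP_SIZE 32

/* CPU simulation of a butterfly (xor) reduction across one warp.
   All lanes exchange at the same time, so each round reads from a
   snapshot of the previous values. After log2(32) = 5 rounds every lane
   holds the sum of all 32 inputs. */
static void warp_reduce_sum_xor(float lanes[WARP_SIZE]) {
    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
        float snapshot[WARP_SIZE];
        for (int i = 0; i < WARP_SIZE; ++i) {
            snapshot[i] = lanes[i];
        }
        for (int i = 0; i < WARP_SIZE; ++i) {
            lanes[i] = snapshot[i] + snapshot[i ^ mask];
        }
    }
}
```

Because every lane receives the total, no lane has to broadcast the result afterwards, and no shared memory or `__syncthreads()` is needed within the warp.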
2023-05-20 | cuda : loading models directly into VRAM, norm calculation on GPU, broadcasting for ggml_mul (#1483) | Johannes Gäßler
* Broadcasting for ggml_mul
* CUDA kernel for ggml_mul, norms in VRAM
* GPU weights not in RAM, direct loading with cuFile
* fixup! GPU weights not in RAM, direct loading with cuFile
* fixup! GPU weights not in RAM, direct loading with cuFile
* define default model path once, sync path with readme (#1366)
* ~7% faster Q5_1 AVX2 code (#1477)
* convert.py: Support models which are stored in a single pytorch_model.bin (#1469)
* Support models in a single pytorch_model.bin
* Remove spurious line with typo
* benchmark-matmul: Print the average of the test results (#1490)
* Remove unused n_parts parameter (#1509)
* Fixes #1511 lambda issue for w64devkit (mingw) (#1513)
* Fix for w64devkit and mingw
* make kv_f16 the default for api users (#1517)
* minor : fix compile warnings
* readme : adds WizardLM to the list of supported models (#1485)
* main : make reverse prompt option act as a stop token in non-interactive mode (#1032)
* Make reverse prompt option act as a stop token in non-interactive scenarios
* Making requested review changes
* Update gpt_params_parse and fix a merge error
* Revert "Update gpt_params_parse and fix a merge error"
  This reverts commit 2bb2ff1748513591ad45b175a75ed1d8089d84c8.
* Update gpt_params_parse and fix a merge error take 2
* examples : add persistent chat (#1495)
* examples : add persistent chat
* examples : fix whitespace
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* tests : add missing header
* ggml : use F16 instead of F32 in Q4_0, Q4_1, Q8_0 (#1508)
* ggml : use F16 instead of F32 in Q4_0, Q4_1 and Q8_0
* llama : bump LLAMA_FILE_VERSION to 3
* cuda : update Q4 and Q8 dequantize kernels
* ggml : fix AVX dot products
* readme : update performance table + hot topics
* ggml : fix scalar implementation of Q4_1 dot
* llama : fix compile warnings in llama_set_state_data()
* llama : fix name shadowing and C4146 (#1526)
* Fix name shadowing and C4146
* Fix if macros not using defined when required
* Update llama-util.h
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Update llama-util.h
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Code style
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
---------
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* Fix for mingw (#1462)
* llama : add llama_init_backend() API (close #1527)
* feature : add blis and other BLAS implementation support (#1502)
* feature: add blis support
* feature: allow all BLA_VENDOR to be assigned in cmake arguments. align with whisper.cpp pr 927
* fix: version detection for BLA_SIZEOF_INTEGER, recover min version of cmake
* Fix typo in INTEGER
  Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
* Revert "feature : add blis and other BLAS implementation support (#1502)"
  This reverts commit 07e9ace0f9da424d82e75df969642522880feb92.
* GPU weights not in RAM, direct loading with cuFile
* llama : code style fixes + progress print fix
* ggml : ggml_mul better broadcast support
* cmake : workarounds for cufile when CMake version < 3.25
* gg rebase fixup
* Loop in llama.cpp, fixed progress callback
* Attempt clang-tidy fix
* llama : fix vram size computation
* Add forgotten fclose()
---------
Co-authored-by: András Salamon <ott2@users.noreply.github.com>
Co-authored-by: Ilya Kurdyukov <59548320+ilyakurdyukov@users.noreply.github.com>
Co-authored-by: Tom Jobbins <784313+TheBloke@users.noreply.github.com>
Co-authored-by: rankaiyx <rankaiyx@rankaiyx.com>
Co-authored-by: Stephan Walter <stephan@walter.name>
Co-authored-by: DannyDaemonic <DannyDaemonic@gmail.com>
Co-authored-by: Erik Scholz <Green-Sky@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: David Kennedy <dakennedyd@gmail.com>
Co-authored-by: Jason McCartney <jmac@theroot.org>
Co-authored-by: Evan Jones <evan.q.jones@gmail.com>
Co-authored-by: Maxime <672982+maximegmd@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zenix <zenixls2@gmail.com>
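Row broadcasting for `ggml_mul` can be pictured as reusing one row of the second operand for every row of the first. The sketch below assumes a contiguous row-major layout and covers only the single-broadcast-row case, indexing the row with `i % ncols`; the function name is hypothetical.

```c
#include <assert.h>
#include <stddef.h>

/* Element-wise multiply where y is a single row of ncols values that is
   repeated for each of the nrows rows of x: dst[i] = x[i] * y[i % ncols]. */
static void mul_broadcast_row(const float *x, const float *y, float *dst,
                              size_t nrows, size_t ncols) {
    assert(ncols > 0);
    for (size_t i = 0; i < nrows * ncols; ++i) {
        dst[i] = x[i] * y[i % ncols];
    }
}
```

A flat index with a modulo maps naturally onto one GPU thread per output element, and it avoids materializing a repeated copy of `y`.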
2023-05-19 | ggml : use F16 instead of F32 in Q4_0, Q4_1, Q8_0 (#1508) | Georgi Gerganov
* ggml : use F16 instead of F32 in Q4_0, Q4_1 and Q8_0
* llama : bump LLAMA_FILE_VERSION to 3
* cuda : update Q4 and Q8 dequantize kernels
* ggml : fix AVX dot products
* readme : update performance table + hot topics
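Storing the per-block scale as F16 instead of F32 saves two bytes per block. Assuming blocks of 32 weights laid out as a scale followed by the packed quants (the structs below are illustrative stand-ins, with F16 modeled as a raw `uint16_t`), the size arithmetic is:

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#define QK4_0 32
#define QK8_0 32

typedef uint16_t fp16_bits; /* storage for an IEEE half, not converted here */

/* Scale + packed quants; "_f32" is the old layout, "_f16" the new one. */
typedef struct { float     d; uint8_t qs[QK4_0 / 2]; } block_q4_0_f32;
typedef struct { fp16_bits d; uint8_t qs[QK4_0 / 2]; } block_q4_0_f16;
typedef struct { float     d; int8_t  qs[QK8_0];     } block_q8_0_f32;
typedef struct { fp16_bits d; int8_t  qs[QK8_0];     } block_q8_0_f16;

/* Bits per weight = bits in one block / weights per block. */
static double bits_per_weight(size_t block_bytes, int weights) {
    assert(weights > 0);
    return 8.0 * (double) block_bytes / weights;
}
```

Under these assumptions Q4_0 goes from 20 to 18 bytes per block (5.0 to 4.5 bits per weight) and Q8_0 from 36 to 34 bytes (9.0 to 8.5 bits per weight). A change in on-disk block size is consistent with the `LLAMA_FILE_VERSION` bump in the same commit.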
2023-05-14 | cuda : deduplicated dequantization code (#1453) | Johannes Gäßler
2023-05-13 | cuda : fix convert function (#1412) | Georgi Gerganov
2023-05-13 | ggml : GPU-accelerated token generation (#1412) | Johannes Gäßler
* CUDA kernel for q4_0 dequant. + mat. vec. mult.
* Added q4_1 via template
* Added missing __syncthreads();
* --gpu_layers -> --gpu-layers
* Shorter dequantize_mul_mat_vec line
* q5_0 dequantize_mul_mat kernel
* More readable dequantize_mul_mat_vec logic
* dequantize_mul_mat_vec kernels for q5_1, q8_0, f16
* llama : offload "output" tensor to GPU too + coding style fixes
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
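A `dequantize_mul_mat_vec` kernel decodes quantized weights on the fly while computing a matrix-vector product, so the dequantized matrix is never materialized. A plain CPU reference for Q4_0 might look like the sketch below. The block layout is an assumption: a float scale `d` plus 16 bytes holding 32 4-bit quants, low nibbles for the first half of the block and high nibbles for the second, each decoded as `d * (q - 8)`. The CUDA version distributes rows and columns across threads instead of looping.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#define QK4_0 32

typedef struct {
    float   d;              /* per-block scale (assumed float here) */
    uint8_t qs[QK4_0 / 2];  /* 32 quants, two per byte */
} block_q4_0;

/* Dot product of one quantized row with the dense vector y. */
static float dot_q4_0_row(const block_q4_0 *row, const float *y, int ncols) {
    assert(ncols % QK4_0 == 0);
    float sum = 0.0f;
    for (int b = 0; b < ncols / QK4_0; ++b) {
        const block_q4_0 *blk = &row[b];
        for (int j = 0; j < QK4_0 / 2; ++j) {
            const int q_lo = (blk->qs[j] & 0x0F) - 8; /* weight j */
            const int q_hi = (blk->qs[j] >> 4) - 8;   /* weight j + 16 */
            sum += blk->d * q_lo * y[b * QK4_0 + j];
            sum += blk->d * q_hi * y[b * QK4_0 + j + QK4_0 / 2];
        }
    }
    return sum;
}

/* dst = W * y, where W has nrows rows of ncols / QK4_0 blocks each. */
static void dequantize_mul_mat_vec_q4_0(const block_q4_0 *w, const float *y,
                                        float *dst, int nrows, int ncols) {
    for (int r = 0; r < nrows; ++r) {
        dst[r] = dot_q4_0_row(w + (size_t) r * (ncols / QK4_0), y, ncols);
    }
}
```

For single-token generation the work is a matrix-vector product that is dominated by reading the weights, so decoding in registers keeps memory traffic at the quantized size.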
2023-05-12 | ggml : remove bit shuffling (#1405) | Georgi Gerganov
* ggml : remove Q4_0 bit shuffling (ARM NEON)
* ggml : remove Q4_1 bit shuffling (ARM NEON + reference)
* ggml : nibbles_from_floats() + bytes_from_nibbles() (ARM NEON)
* ggml : remove Q4_2 bit shuffling (WIP, BROKEN)
* ggml : remove Q5_0 bit shuffling (ARM NEON)
* ggml : 2x faster scalar implementations
* ggml : remove Q5_1 bit shuffling (ARM NEON + scalar)
* ggml : simplify scalar dot
* ggml : remove WASM SIMD bit shuffling + remove vzip for ARM 32-bit
* ggml : fix Q4_1 quantization
* ggml : update cuBLAS + normalize variable names
* ggml : remove Q4_2 mode
* ggml : minor formatting
* ggml : fix Q5_0 quantization
* scripts : add script for measuring the time per token
* AVX implementations (#1370)
* ggml : uniform 5th bit extraction
* llama : produce error upon loading old model files
* llama : fix model magic/version write
* ggml : speed-up Q5_0 + Q5_1 at 4 threads
* ggml : preserve old Q4 and Q5 formats
* ggml : simplify Q8_1 - no need for low / high sums anymore
* ggml : fix Q8_0 and Q8_1 rounding
* Revert "AVX implementations (#1370)"
  This reverts commit 948d124837f9d287d8490f41338e0e4cceb0814f.
* ggml : fix AVX2 implementation
* sha : update hashes for 7B and 13B
* readme : update timings + remove warning banner
* llama : update v2 PR number to 1405
* ggml : fix WASM comments
* ggml : back to original bit order
* readme : add note that Q4 and Q5 have been changed
* llama : fix return for unknown version
---------
Co-authored-by: Stephan Walter <stephan@walter.name>
2023-05-08 | Documented CUDA reproducibility, added warning (#1346) | Johannes Gäßler
2023-05-01 | cuBLAS: refactor and optimize f16 mat mul performance (#1259) | slaren
* cuBLAS: refactor, convert fp16 to fp32 on device
* cuBLAS: use multiple streams, choose smartly between mul_mat_q and mul_mat_f16
* fix build
* cuBLAS: update block_q5_1
2023-05-01 | cuBLAS: fall back to pageable memory if pinned alloc fails (#1233) | slaren
* cuBLAS: fall back to pageable memory if pinned alloc fails
* cuBLAS: do not use pinned memory if env variable GGML_CUDA_NO_PINNED is set
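The fallback logic can be sketched in C without a GPU by passing the pinned allocator in as a callback. The callback type and `pinned_alloc_unavailable` are hypothetical stand-ins for a call such as `cudaMallocHost`; `GGML_CUDA_NO_PINNED` is the environment variable named in the commit. A real implementation must also remember which allocator succeeded so it can release the memory with the matching free function.

```c
#include <assert.h>
#include <stdbool.h>
#include <stdlib.h>

/* Hypothetical pinned-memory allocator signature; returns NULL on failure. */
typedef void *(*pinned_alloc_fn)(size_t size);

/* Try pinned memory unless GGML_CUDA_NO_PINNED is set; otherwise, or if the
   pinned allocation fails, fall back to ordinary pageable memory. */
static void *host_alloc(size_t size, pinned_alloc_fn pinned_alloc, bool *is_pinned) {
    assert(is_pinned != NULL);
    *is_pinned = false;
    if (getenv("GGML_CUDA_NO_PINNED") == NULL && pinned_alloc != NULL) {
        void *ptr = pinned_alloc(size);
        if (ptr != NULL) {
            *is_pinned = true;
            return ptr;
        }
    }
    return malloc(size);
}

/* Stand-in for a pinned allocation that fails (for example, too little
   page-locked memory available). */
static void *pinned_alloc_unavailable(size_t size) {
    (void) size;
    return NULL;
}
```

Pinned (page-locked) host memory enables faster, asynchronous host-to-device copies, but it is a limited resource, so degrading to pageable memory keeps the program running at some cost in transfer speed.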
2023-04-29 | cuBLAS: use host pinned memory and dequantize while copying (#1207) | slaren
* cuBLAS: dequantize simultaneously while copying memory
* cuBLAS: use host pinned memory
* cuBLAS: improve ggml_compute_forward_mul_mat_f16_f32 with pinned memory
* cuBLAS: also pin kv cache
* fix rebase
2023-04-29 | cuBLAS: non-contiguous tensor support (#1215) | Henri Vasserman
* Cuda: non-contiguous tensor support
* remove extra stuff
* rename
* fix error
* more fixes, now OpenBLAS and CLBlast build too
* now then?
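ggml tensors describe their layout with element counts per dimension (`ne`) and byte strides per dimension (`nb`), so a transposed or sliced view can share the same memory with different strides. The 2-D struct below is a simplified, hypothetical stand-in for `ggml_tensor` (which has four dimensions); it shows how an element address is computed from strides and how contiguity can be detected.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Simplified 2-D tensor view: ne[] = elements per dimension,
   nb[] = stride in bytes per dimension. */
typedef struct {
    int64_t ne[2];
    size_t  nb[2];
    void   *data;
} tensor2d;

/* Address element (i0, i1) through the byte strides. */
static float get_f32(const tensor2d *t, int64_t i0, int64_t i1) {
    assert(i0 >= 0 && i0 < t->ne[0] && i1 >= 0 && i1 < t->ne[1]);
    const char *base = (const char *) t->data;
    return *(const float *) (base + (size_t) i0 * t->nb[0] + (size_t) i1 * t->nb[1]);
}

/* A float tensor is contiguous when rows are packed back to back. */
static int is_contiguous(const tensor2d *t) {
    return t->nb[0] == sizeof(float) && t->nb[1] == t->nb[0] * (size_t) t->ne[0];
}
```

A library routine such as a BLAS GEMM expects a dense (or at least regularly strided) matrix, so one common way to support non-contiguous inputs is to gather them into a contiguous buffer through the strides before the call.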
2023-04-28 | Remove Q4_3 which is no better than Q5 (#1218) | Stephan Walter
2023-04-26 | ggml : add Q5_0 and Q5_1 quantization (#1187) | Georgi Gerganov
* ggml : add Q5_0 quantization (cuBLAS only)
* ggml : fix Q5_0 qh -> uint32_t
* ggml : fix q5_0 histogram stats
* ggml : q5_0 scalar dot product
* ggml : q5_0 ARM NEON dot
* ggml : q5_0 more efficient ARM NEON using uint64_t masks
* ggml : rename Q5_0 -> Q5_1
* ggml : adding Q5_0 mode
* quantize : add Q5_0 and Q5_1 to map
* ggml : AVX2 optimizations for Q5_0, Q5_1 (#1195)
---------
Co-authored-by: Stephan Walter <stephan@walter.name>
2023-04-25 | ggml : add Q8_0 quantization format (rename the old one to Q8_1) (ARM NEON) (#1179) | Georgi Gerganov
* ggml : add Q8_0 quantization format (rename the old one to Q8_1)
* tests : fix test-quantize-fns
* ggml : finalize Q8_0 implementation
* ggml : use q4_0_q8_0 and q4_2_q8_0
* ggml : fix Q8_0 dot product bug (ARM)
* ggml : Q8_0 unroll x2
* ggml : fix bug - using wrong block type
* ggml : extend quantize_fns_t with "vec_dot_type"
* ggml : fix Q8_0 to use 255 values out of 256
* ggml : fix assert using wrong QK4_2 instead of QK4_3
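"Use 255 values out of 256" presumably means symmetric quantization: with scale `d = amax / 127`, every quant lands in [-127, 127], leaving -128 unused so the range is symmetric around zero. A sketch of one Q8_0-style block under that assumption (block size 32, float scale, round half away from zero implemented by hand to avoid linking libm):

```c
#include <assert.h>
#include <stdint.h>

#define QK8_0 32

/* Quantize one block of QK8_0 floats to int8 and return the scale d,
   so that x[j] is approximately d * q[j]. */
static float quantize_block_q8_0(const float *x, int8_t *q) {
    float amax = 0.0f; /* largest absolute value in the block */
    for (int j = 0; j < QK8_0; ++j) {
        const float a = x[j] < 0 ? -x[j] : x[j];
        if (a > amax) amax = a;
    }
    const float d = amax / 127.0f;
    const float id = d != 0.0f ? 1.0f / d : 0.0f;
    for (int j = 0; j < QK8_0; ++j) {
        const float s = x[j] * id; /* lies in [-127, 127] up to rounding */
        q[j] = (int8_t) (s >= 0 ? (int) (s + 0.5f) : -(int) (-s + 0.5f));
    }
    return d;
}
```

Dividing by 127 instead of 128 costs one code point but keeps `amax` and `-amax` exactly representable as +127 and -127.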
2023-04-21 | Improve cuBLAS performance by using a memory pool (#1094) | slaren
* Improve cuBLAS performance by using a memory pool
* Move cuda specific definitions to ggml-cuda.h/cu
* Add CXX flags to nvcc
* Change memory pool synchronization mechanism to a spin lock
  General code cleanup
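A buffer pool avoids paying for a device allocation on every matrix multiplication by caching freed buffers and handing them out again. The sketch below is a hypothetical C11 version protected by an `atomic_flag` spin lock; `malloc`/`free` stand in for device allocation calls so it runs on the CPU, and the slot count of 16 is an assumption.

```c
#include <assert.h>
#include <stdatomic.h>
#include <stddef.h>
#include <stdlib.h>

#define MAX_BUFFERS 16

typedef struct {
    void  *ptr;
    size_t size;
} pool_buffer;

static pool_buffer g_pool[MAX_BUFFERS];
static atomic_flag g_pool_lock = ATOMIC_FLAG_INIT;

static void pool_lock(void) {
    while (atomic_flag_test_and_set_explicit(&g_pool_lock, memory_order_acquire)) {
        /* busy-wait: critical sections are short, so spinning is cheap */
    }
}

static void pool_unlock(void) {
    atomic_flag_clear_explicit(&g_pool_lock, memory_order_release);
}

/* Reuse a cached buffer that is large enough, otherwise allocate a new one.
   *actual_size receives the real capacity, which the caller passes back
   to pool_free. */
static void *pool_malloc(size_t size, size_t *actual_size) {
    pool_lock();
    for (int i = 0; i < MAX_BUFFERS; ++i) {
        if (g_pool[i].ptr != NULL && g_pool[i].size >= size) {
            void *ptr = g_pool[i].ptr;
            *actual_size = g_pool[i].size;
            g_pool[i].ptr = NULL;
            g_pool[i].size = 0;
            pool_unlock();
            return ptr;
        }
    }
    pool_unlock();
    *actual_size = size;
    return malloc(size);
}

/* Cache the buffer in a free slot; release it only if the pool is full. */
static void pool_free(void *ptr, size_t size) {
    pool_lock();
    for (int i = 0; i < MAX_BUFFERS; ++i) {
        if (g_pool[i].ptr == NULL) {
            g_pool[i].ptr = ptr;
            g_pool[i].size = size;
            pool_unlock();
            return;
        }
    }
    pool_unlock();
    free(ptr);
}
```

A spin lock suits this case because each critical section is only a short scan over a small array; the expensive allocation happens outside the lock.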
2023-04-20 | Add Q4_3 support to cuBLAS (#1086) | slaren
2023-04-20 | Improve cuBLAS performance by dequantizing on the GPU (#1065) | slaren