aboutsummaryrefslogtreecommitdiff
path: root/ggml-opencl-dequant.cl
blob: 191b2e57500ad060524732c964ea2b03691e2a57 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
// Turns its whole argument list into one string literal, so the OpenCL
// program that follows can be written as plain source text instead of
// line-by-line quoted strings.
// Side effects of routing code through a macro argument: every run of
// whitespace (newlines included) collapses to a single space in the
// resulting string, comments are dropped before stringification, and
// preprocessor directives must not appear inside the argument.
#define MULTILINE_QUOTE(...) #__VA_ARGS__
const char * clblast_dequant = MULTILINE_QUOTE(

// One block as consumed by dequantize_row_q4_0: a float scale followed by
// 16 bytes, each byte packing two 4-bit values (32 values per block).
// NOTE(review): presumably this has to match the layout of the buffer the
// host uploads, byte for byte - confirm against the host code.
struct block_q4_0
{
    float d;       // scale multiplied into every value of the block
    uchar qs[16];  // packed 4-bit values; low nibble is decoded first
};

// Expands q4_0 blocks into floats: value = (nibble - 8) * d.
//
// blocks: input blocks; only read, hence const.
// result: output floats, 32 consecutive entries per block.
//
// Each work-item decodes one byte of one block and stores two floats.
// The byte index is the local id, so the kernel expects a work-group size of
// 16 in dimension 0: a larger size would index past qs, a smaller one would
// leave part of every block unwritten.
// NOTE(review): the block index is global id / 32 although a block only has
// 16 bytes, so with a work-group size of 16 and no global offset two
// work-groups map to the same block and store identical values - confirm
// that this matches the intended host launch configuration.
__kernel void dequantize_row_q4_0(__global const struct block_q4_0* blocks, __global float* result) {
    const uint i = get_global_id(0) / 32;  // block handled by this work-item
    const uint l = get_local_id(0);        // byte inside that block

    const float d = blocks[i].d;

    const uchar vi = blocks[i].qs[l];

    // Two outputs per byte: low nibble first, high nibble second.
    const uint index = i*32 + l*2;
    result[index + 0] = ((vi & 0xf) - 8)*d;
    result[index + 1] = ((vi >> 4) - 8)*d;
}

// One block as consumed by dequantize_row_q4_1: a float scale, a float
// offset, and 16 bytes that each pack two 4-bit values (32 values per block).
// NOTE(review): presumably this has to match the layout of the buffer the
// host uploads, byte for byte - confirm against the host code.
struct block_q4_1
{
    float d;       // scale multiplied into every value of the block
    float m;       // offset added after scaling
    uchar qs[16];  // packed 4-bit values; low nibble is decoded first
};

// Expands q4_1 blocks into floats: value = nibble * d + m.
//
// blocks: input blocks; only read, hence const.
// result: output floats, 32 consecutive entries per block.
//
// Each work-item decodes one byte of one block and stores two floats.
// The byte index is the local id, so the kernel expects a work-group size of
// 16 in dimension 0: a larger size would index past qs, a smaller one would
// leave part of every block unwritten.
// NOTE(review): the block index is global id / 32 although a block only has
// 16 bytes, so with a work-group size of 16 and no global offset two
// work-groups map to the same block and store identical values - confirm
// that this matches the intended host launch configuration.
__kernel void dequantize_row_q4_1(__global const struct block_q4_1* blocks, __global float* result) {
    const uint i = get_global_id(0) / 32;  // block handled by this work-item
    const uint l = get_local_id(0);        // byte inside that block

    const float d = blocks[i].d;
    const float m = blocks[i].m;

    const uchar vi = blocks[i].qs[l];

    // Two outputs per byte: low nibble first, high nibble second.
    const uint index = i*32 + l*2;
    result[index + 0] = (vi & 0xf) * d + m;
    result[index + 1] = (vi >> 4) * d + m;
}

// One block as consumed by dequantize_row_q4_2: a 16-bit scale followed by
// 8 bytes, each byte packing two 4-bit values (16 values per block).
// NOTE(review): presumably this has to match the layout of the buffer the
// host uploads, byte for byte - confirm against the host code.
struct block_q4_2
{
    ushort d;     // scale; the kernel reads these 16 bits as a half via vload_half
    uchar qs[8];  // packed 4-bit values; low nibble is decoded first
};

// Expands q4_2 blocks into floats: value = (nibble - 8) * d.
//
// blocks: input blocks; only read, hence const.
// result: output floats, 16 consecutive entries per block.
//
// Each work-item decodes one byte of one block and stores two floats.
// The byte index is the local id, so the kernel expects a work-group size of
// 8 in dimension 0: a larger size would index past qs, a smaller one would
// leave part of every block unwritten.
// NOTE(review): the block index is global id / 16 although a block only has
// 8 bytes, so with a work-group size of 8 and no global offset two
// work-groups map to the same block and store identical values - confirm
// that this matches the intended host launch configuration.
__kernel void dequantize_row_q4_2(__global const struct block_q4_2* blocks, __global float* result) {
    const uint i = get_global_id(0) / 16;  // block handled by this work-item
    const uint l = get_local_id(0);        // byte inside that block

    // d is stored as raw 16 bits; vload_half reads them as a half and
    // returns the value converted to float.
    const float d = vload_half(0, (__global const half*) &(blocks[i].d));

    const uchar vi = blocks[i].qs[l];

    // Two outputs per byte: low nibble first, high nibble second.
    const uint index = i*16 + l*2;
    result[index + 0] = ((vi & 0xf) - 8)*d;
    result[index + 1] = ((vi >> 4) - 8)*d;
}

// One block as consumed by dequantize_row_q4_3: a 16-bit scale, a 16-bit
// offset, and 8 bytes that each pack two 4-bit values (16 values per block).
// NOTE(review): presumably this has to match the layout of the buffer the
// host uploads, byte for byte - confirm against the host code.
struct block_q4_3
{
    ushort d;     // scale; the kernel reads these 16 bits as a half via vload_half
    ushort m;     // offset; read the same way as d
    uchar qs[8];  // packed 4-bit values; low nibble is decoded first
};

// Expands q4_3 blocks into floats: value = nibble * d + m.
//
// blocks: input blocks; only read, hence const.
// result: output floats, 16 consecutive entries per block.
//
// Each work-item decodes one byte of one block and stores two floats.
// The byte index is the local id, so the kernel expects a work-group size of
// 8 in dimension 0: a larger size would index past qs, a smaller one would
// leave part of every block unwritten.
// NOTE(review): the block index is global id / 16 although a block only has
// 8 bytes, so with a work-group size of 8 and no global offset two
// work-groups map to the same block and store identical values - confirm
// that this matches the intended host launch configuration.
__kernel void dequantize_row_q4_3(__global const struct block_q4_3* blocks, __global float* result) {
    const uint i = get_global_id(0) / 16;  // block handled by this work-item
    const uint l = get_local_id(0);        // byte inside that block

    // d and m are stored as raw 16 bits; vload_half reads them as halves and
    // returns the values converted to float.
    const float d = vload_half(0, (__global const half*) &(blocks[i].d));
    const float m = vload_half(0, (__global const half*) &(blocks[i].m));

    const uchar vi = blocks[i].qs[l];

    // Two outputs per byte: low nibble first, high nibble second.
    const uint index = i*16 + l*2;
    result[index + 0] = (vi & 0xf) * d + m;
    result[index + 1] = (vi >> 4) * d + m;
}

);