Skip to content

Commit 390f0a9

Browse files
Faster dequantization
1 parent 8d8de07 commit 390f0a9

File tree

1 file changed

+7
-12
lines changed

1 file changed

+7
-12
lines changed

ggml-cuda.cu

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,22 +86,17 @@ static __global__ void dequantize_block_q4_0(const void * vx, float * y, int k)
8686
const int i = blockIdx.x*blockDim.x + threadIdx.x;
8787

8888
if (i < k) {
89-
const float d = x[i].d;
89+
const float d = x[i/QK4_0].d;
9090

91-
const uint8_t * pp = x[i].qs;
91+
const uint8_t * pp = x[i/QK4_0].qs;
9292

93-
for (int l = 0; l < QK4_0; l += 2) {
94-
const uint8_t vi = pp[l/2];
93+
const uint8_t vui = pp[(i%QK4_0)/2];
9594

96-
const int8_t vi0 = vi & 0xf;
97-
const int8_t vi1 = vi >> 4;
95+
const int8_t vi = (vui >> (4 * (i&1))) & 0xF;
9896

99-
const float v0 = (vi0 - 8)*d;
100-
const float v1 = (vi1 - 8)*d;
97+
const float v = (vi - 8)*d;
10198

102-
y[i*QK4_0 + l + 0] = v0;
103-
y[i*QK4_0 + l + 1] = v1;
104-
}
99+
y[i] = v;
105100
}
106101
}
107102

@@ -238,7 +233,7 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y, int k)
238233
}
239234

240235
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
241-
const int nb = k / QK4_0;
236+
const int nb = k;
242237
int min_grid_size, block_size = 1; // Initialize to suppress compiler warning.
243238
CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_0, 0, 0));
244239
int grid_size = (nb + block_size - 1) / block_size; // Round up.

0 commit comments

Comments
 (0)