@@ -358,10 +358,11 @@ static void pad_f32(const float *x, float *dst, const int ne0, const int ne00,
358
358
}
359
359
}
360
360
361
+ template <int QUANT_BLOCK_TILE>
361
362
static void quantize_q8_1 (const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
362
363
const sycl::nd_item<3 > &item_ct1) {
363
- const int ix = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
364
- item_ct1.get_local_id (2 );
364
+ const int ix = ( item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
365
+ item_ct1.get_local_id (2 )) * QUANT_BLOCK_TILE ;
365
366
366
367
if (ix >= kx_padded) {
367
368
return ;
@@ -376,23 +377,39 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
376
377
377
378
const int ib = i_padded / QK8_1; // block index
378
379
const int iqs = i_padded % QK8_1; // quant index
379
-
380
- const float xi = ix < kx ? x[iy*kx + ix] : 0 .0f ;
381
- float amax = sycl::fabs ((float )xi);
382
- float sum = xi;
383
-
380
+ typedef sycl::vec<float , QUANT_BLOCK_TILE> TC;
381
+ typedef sycl::vec<int8_t , QUANT_BLOCK_TILE> TQ;
382
+ TC zeros;
383
+ TQ qzeros;
384
384
#pragma unroll
385
- for (int mask = WARP_SIZE / 2 ; mask > 0 ; mask >>= 1 ) {
386
- amax = sycl::fmax (amax, dpct::permute_sub_group_by_xor (
387
- item_ct1.get_sub_group (), amax, mask));
388
- sum +=
389
- dpct::permute_sub_group_by_xor (item_ct1.get_sub_group (), sum, mask);
385
+ for (int i = 0 ; i < QUANT_BLOCK_TILE; i++)
386
+ {
387
+ zeros[i] = 0 .f ;
388
+ qzeros[i] = 0 ;
389
+ }
390
+ const TC xi = ix < kx ? *(TC *)&x[iy * kx + ix] : zeros;
391
+ float sum = xi[0 ];
392
+ float amax = sycl::fabs (xi[0 ]);
393
+ #pragma unroll
394
+ for (int i = 1 ; i < QUANT_BLOCK_TILE; i++)
395
+ {
396
+ sum += xi[i];
397
+ amax = sycl::fmax (sycl::fabs (xi[i]), amax);
390
398
}
399
+ sum = warp_reduce_sum (sum, item_ct1);
400
+ amax = warp_reduce_max (amax, item_ct1);
391
401
392
402
const float d = amax / 127 ;
393
- const int8_t q = amax == 0 .0f ? 0 : sycl::round (xi / d);
403
+ TQ q = qzeros;
404
+ if (amax != 0 .0f )
405
+ {
406
+ #pragma unroll
407
+ for (int i = 0 ; i < QUANT_BLOCK_TILE; i++) {
408
+ q[i] = sycl::round (xi[i] / d);
409
+ }
410
+ }
394
411
395
- y[ib].qs [iqs] = q;
412
+ *(TQ *)& y[ib].qs [iqs] = q;
396
413
397
414
if (iqs > 0 ) {
398
415
return ;
@@ -1595,15 +1612,17 @@ static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
1595
1612
queue_ptr stream) {
1596
1613
const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1 ) / SYCL_QUANTIZE_BLOCK_SIZE;
1597
1614
const sycl::range<3 > num_blocks (1 , ky, block_num_x);
1598
- const sycl::range<3 > block_size (1 , 1 , SYCL_DEQUANTIZE_BLOCK_SIZE);
1615
+ int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
1616
+ static_assert (QK8_1 % WARP_SIZE == 0 );
1617
+ const sycl::range<3 > block_size (1 , 1 , SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
1599
1618
{
1600
1619
dpct::has_capability_or_fail (stream->get_device (),
1601
1620
{sycl::aspect::fp16});
1602
1621
1603
1622
stream->parallel_for (
1604
1623
sycl::nd_range<3 >(num_blocks * block_size, block_size),
1605
1624
[=](sycl::nd_item<3 > item_ct1) [[intel::reqd_sub_group_size (WARP_SIZE)]] {
1606
- quantize_q8_1 (x, vy, kx, kx_padded, item_ct1);
1625
+ quantize_q8_1<QUANT_BLOCK_TILE> (x, vy, kx, kx_padded, item_ct1);
1607
1626
});
1608
1627
}
1609
1628
}
@@ -4170,7 +4189,6 @@ bool ggml_sycl_supports_dmmv(enum ggml_type type) {
4170
4189
4171
4190
static void ggml_sycl_mul_mat (ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
4172
4191
const bool split = ggml_backend_buffer_is_sycl_split (src0->buffer );
4173
-
4174
4192
int64_t min_compute_capability = INT_MAX;
4175
4193
4176
4194
if (split) {
0 commit comments