@@ -1434,6 +1434,59 @@ static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy,
1434
1434
reinterpret_cast <sycl::half &>(y[ib].ds .y ()) = sum;
1435
1435
}
1436
1436
1437
+ template <int ElementsPerWI>
1438
+ static __dpct_inline__ void quantize_and_reorder_q8_1 (const float * __restrict__ x, void * reordered_q8_tensor,
1439
+ const int kx, const int kx_padded, const sycl::nd_item<1 > & it) {
1440
+ /*
1441
+ Quantizes and reorders the resultant q8 tensor in a per row fashion
1442
+ Each sub-group calculates one quant block. i.e. QK8_1 quant values and the d and sum values
1443
+ */
1444
+
1445
+ auto subgroup_id = it.get_group (0 );
1446
+ auto wi_id = it.get_local_id (0 );
1447
+
1448
+ const int num_blocks_per_row = kx / QK8_1;
1449
+ auto row = subgroup_id / num_blocks_per_row;
1450
+ auto col = subgroup_id % num_blocks_per_row;
1451
+
1452
+ auto row_offset = row * (kx_padded / QK8_1) * sizeof (block_q8_1);
1453
+ auto col_offset = QK8_1 * col + wi_id * ElementsPerWI;
1454
+
1455
+ auto quant_ptr = (int8_t *) ((char *) reordered_q8_tensor + row_offset + col_offset);
1456
+ auto ds_ptr = (sycl::half2 *) ((char *) reordered_q8_tensor + row_offset + kx + col * sizeof (sycl::half2));
1457
+
1458
+ sycl::vec<float , ElementsPerWI> wi_f32_vals;
1459
+ sycl::vec<int8_t , ElementsPerWI> quantized_values;
1460
+
1461
+ auto float_ptr_offset = subgroup_id * QK8_1 + ElementsPerWI * wi_id;
1462
+ wi_f32_vals = *reinterpret_cast <const sycl::vec<float , ElementsPerWI> *>(x + float_ptr_offset);
1463
+
1464
+ float sum = 0 .0f ;
1465
+ float amax = 0 .0f ;
1466
+
1467
+ #pragma unroll(ElementsPerWI)
1468
+ for (int i = 0 ; i < ElementsPerWI; i++) {
1469
+ sum += wi_f32_vals[i];
1470
+ amax = sycl::fmax (amax, sycl::fabs (wi_f32_vals[i]));
1471
+ quantized_values[i] = 0 ;
1472
+ }
1473
+ sum = sycl::reduce_over_group (it.get_group (), sum, sycl::plus<float >());
1474
+ amax = sycl::reduce_over_group (it.get_group (), amax, sycl::maximum<float >());
1475
+ float d = amax == 0 ? 1 : amax / 127 ;
1476
+
1477
+ #pragma unroll(ElementsPerWI)
1478
+ for (int i = 0 ; i < ElementsPerWI; i++) {
1479
+ quantized_values[i] = sycl::round (wi_f32_vals[i] / d);
1480
+ }
1481
+
1482
+ d = amax == 0 ? 0 : d;
1483
+
1484
+ *reinterpret_cast <sycl::vec<int8_t , ElementsPerWI> *>(quant_ptr) = quantized_values;
1485
+ if (wi_id == 0 ) {
1486
+ *ds_ptr = sycl::half2 (sycl::half (d), sycl::half (sum));
1487
+ }
1488
+ }
1489
+
1437
1490
static void mul_mat_p021_f16_f32 (
1438
1491
const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
1439
1492
const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
@@ -1718,23 +1771,30 @@ static void pool2d_nchw_kernel(
1718
1771
o_ptr[cur_oh * ow + cur_ow] = res;
1719
1772
}
1720
1773
1721
- static void quantize_row_q8_1_sycl (const float *x, void *vy, const int kx,
1722
- const int ky, const int kx_padded,
1723
- queue_ptr stream) {
1724
- const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1 ) / SYCL_QUANTIZE_BLOCK_SIZE;
1725
- const sycl::range<3 > num_blocks (1 , ky, block_num_x);
1726
- int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
1727
- static_assert (QK8_1 % WARP_SIZE == 0 );
1728
- const sycl::range<3 > block_size (1 , 1 , SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
1729
- {
1730
- dpct::has_capability_or_fail (stream->get_device (),
1731
- {sycl::aspect::fp16});
1774
+ static void quantize_row_q8_1_sycl (const float * x, void * vy, const int kx, const int ky, const int kx_padded,
1775
+ bool reorder_q8_tensor, queue_ptr stream) {
1776
+ if (reorder_q8_tensor) {
1777
+ auto local_range = std::size_t (WARP_SIZE);
1778
+ auto num_quant_blocks = ky * (kx / QK8_1);
1779
+ auto global_range = num_quant_blocks * local_range;
1780
+ stream->parallel_for (sycl::nd_range<1 >({ global_range }, { local_range }),
1781
+ [=](sycl::nd_item<1 > it) [[sycl::reqd_sub_group_size (WARP_SIZE)]] {
1782
+ quantize_and_reorder_q8_1<QK8_1 / WARP_SIZE>(x, vy, kx, kx_padded, it);
1783
+ });
1784
+ } else {
1785
+ const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1 ) / SYCL_QUANTIZE_BLOCK_SIZE;
1786
+ const sycl::range<3 > num_blocks (1 , ky, block_num_x);
1787
+ int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
1788
+ static_assert (QK8_1 % WARP_SIZE == 0 );
1789
+ const sycl::range<3 > block_size (1 , 1 , SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
1790
+ {
1791
+ dpct::has_capability_or_fail (stream->get_device (), { sycl::aspect::fp16 });
1732
1792
1733
- stream->parallel_for (
1734
- sycl::nd_range <3 >(num_blocks * block_size, block_size),
1735
- [=](sycl::nd_item< 3 > item_ct1) [[ sycl::reqd_sub_group_size (WARP_SIZE)]] {
1736
- quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1 );
1737
- });
1793
+ stream->parallel_for (sycl::nd_range< 3 >(num_blocks * block_size, block_size),
1794
+ [=]( sycl::nd_item <3 > item_ct1) [[ sycl::reqd_sub_group_size (WARP_SIZE)]] {
1795
+ quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
1796
+ } );
1797
+ }
1738
1798
}
1739
1799
}
1740
1800
@@ -2446,9 +2506,10 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
2446
2506
dev[i].src1_ddq = dev[i].src1_ddq_alloc .alloc (ctx.pool (i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
2447
2507
2448
2508
if (src1_on_device && src1_is_contiguous) {
2509
+ bool reorder_q8_tensor = src0->extra && ((ggml_tensor_extra_gpu *)src0->extra )->optimized_feature .reorder ;
2449
2510
scope_op_debug_print scope_dbg_print (__func__, " /quantize_row_q8_1_sycl" , dst,
2450
2511
/* num_src=*/ 2 , " : converting src1 to Q8_1" );
2451
- quantize_row_q8_1_sycl (dev[i].src1_ddf , dev[i].src1_ddq , ne10, nrows1, src1_padded_col_size, stream);
2512
+ quantize_row_q8_1_sycl (dev[i].src1_ddf , dev[i].src1_ddq , ne10, nrows1, src1_padded_col_size, reorder_q8_tensor, stream);
2452
2513
/*
2453
2514
DPCT1010:90: SYCL uses exceptions to report errors and does not
2454
2515
use the error codes. The call was replaced with 0. You need to
@@ -2554,7 +2615,7 @@ static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_ten
2554
2615
if (convert_src1_to_q8_1 && !src1_is_contiguous) {
2555
2616
scope_op_debug_print scope_dbg_print (__func__, " /quantize_row_q8_1_sycl" , dst,
2556
2617
/* num_src=*/ 2 , " : converting src1 to Q8_1" );
2557
- quantize_row_q8_1_sycl (src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
2618
+ quantize_row_q8_1_sycl (src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, false , stream);
2558
2619
/*
2559
2620
DPCT1010:92: SYCL uses exceptions to report errors and does
2560
2621
not use the error codes. The call was replaced with 0. You
0 commit comments