#include <ideep.hpp>
#include "aten/Linear.h"
#include "aten/WeightPack.h"
+ #include "csrc/cpu/tpp/woq/tla.h"
#include "ideep/IDeepConversions.h"

namespace torch_ipex {
@@ -16,6 +17,7 @@ c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContext(
    at::Tensor&& scales,
    at::Tensor&& zero_points,
    c10::optional<at::Tensor>&& bias,
+   c10::optional<at::Tensor>&& g_idx,
    c10::optional<int64_t> batch_size,
    bool is_int4,
    int64_t group_size,
@@ -32,6 +34,7 @@ c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContext(
      std::move(scales),
      std::move(zero_points),
      std::move(bias),
+     std::move(g_idx),
      batch_size,
      is_int4,
      group_size,
@@ -45,6 +48,7 @@ c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContextInt4(
    at::Tensor&& scales,
    at::Tensor&& zero_points,
    c10::optional<at::Tensor>&& bias,
+   c10::optional<at::Tensor>&& g_idx,
    c10::optional<int64_t> batch_size,
    int64_t group_size, // group_size along input channel
    int64_t lowp_mode,
@@ -160,6 +164,7 @@ c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContextInt4(
      std::move(scales_fp32),
      std::move(zp_fp32),
      std::move(bias),
+     std::move(g_idx),
      batch_size,
      /*is_int4*/ true,
      group_size,
@@ -183,17 +188,29 @@ ContextLinearWoq create(
    at::Tensor& scales,
    at::Tensor& zero_points,
    const c10::optional<at::Tensor>& bias,
+   const c10::optional<at::Tensor>& g_idx,
    const c10::optional<int64_t> batch_size,
    bool is_int4,
    int64_t group_size,
    int64_t lowp_mode,
    int64_t num_concats,
    int64_t act_quant_mode) {
-  auto packed_weight = woq_linear_pack_weight(
-      weight, weight_shape, is_int4, group_size, lowp_mode);
+  at::Tensor packed_weight;
+  int64_t N = weight_shape[0];
+  int64_t K = weight_shape[1];
+  // GPTQ with act-order:
+  // shuffle weight along ic to make channels in the same group contiguous
+  if (is_int4 && group_size > 0 && g_idx.has_value()) {
+    auto shuffled_weight = woq_shuffle_tensor_by_group_idx</*is_int4*/ true>(
+        weight, weight_shape, g_idx.value(), group_size);
+    packed_weight = woq_linear_pack_weight(
+        shuffled_weight, weight_shape, is_int4, group_size, lowp_mode);
+  } else {
+    packed_weight = woq_linear_pack_weight(
+        weight, weight_shape, is_int4, group_size, lowp_mode);
+  }
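+  // Example of the shuffle semantics: with g_idx = [0, 1, 0, 1] and
+  // group_size = 2, channels {0, 2} belong to group 0 and {1, 3} to
+  // group 1, so the input channels are reordered to [0, 2, 1, 3].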
  auto packed_shape = packed_weight.sizes();
-  int64_t N = weight.size(0);
-  int64_t K = weight.size(1);
  // If OC is not a multiple of BLOCK_N, it may be padded.
  bool oc_is_padded = (packed_shape.size() == 4 && is_int4 &&
                       packed_shape[0] * packed_shape[3] * 2 != N) ||
@@ -221,6 +238,7 @@ ContextLinearWoq create(
      std::move(scales_padded),
      std::move(zero_points_padded),
      c10::make_optional(bias_padded),
+     g_idx.has_value() ? c10::make_optional(*g_idx) : c10::nullopt,
      is_int4,
      group_size,
      lowp_mode,
@@ -233,6 +251,7 @@ ContextLinearWoq create(
      std::move(scales_padded),
      std::move(zero_points_padded),
      c10::nullopt,
+     g_idx.has_value() ? c10::make_optional(*g_idx) : c10::nullopt,
      is_int4,
      group_size,
      lowp_mode,
@@ -246,13 +265,30 @@ ContextLinearWoq create(
      std::move(scales),
      std::move(zero_points_float),
      bias.has_value() ? c10::make_optional(*bias) : c10::nullopt,
+     g_idx.has_value() ? c10::make_optional(*g_idx) : c10::nullopt,
      is_int4,
      group_size,
      lowp_mode,
      num_concats,
      act_quant_mode);
}

+ static at::Tensor _shuffle_input_channels_if_needed(
+     ContextLinearWoq& context,
+     const at::Tensor& input) {
+   // GPTQ with act-order:
+   // shuffle input channels to align with the shuffled weight
+   if (context.is_int4_ && context.group_size_ > 0 &&
+       context.g_idx_.has_value()) {
+     auto& g_idx = context.g_idx_.value();
+     auto K = input.size(-1);
+     std::vector<int64_t> input_shape = {input.numel() / K, K};
+     return woq_shuffle_tensor_by_group_idx(
+         input, input_shape, g_idx, context.group_size_);
+   }
+   return input;
+ }
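+ // Note: applying the same g_idx permutation to both the activation's and
+ // the weight's input channels leaves every dot product unchanged, so the
+ // shuffled computation matches the unshuffled one.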
+
at::Tensor run(ContextLinearWoq& context, const at::Tensor& input) {
  // TPP kernel packs weight to 4d (Nc, Kc, block_k, block_n)
  auto w_k = context.weight_shape_[1];
@@ -264,6 +300,8 @@ at::Tensor run(ContextLinearWoq& context, const at::Tensor& input) {
      w_k,
      " respectively.");
  auto input_ = input.contiguous();
+   // handle GPTQ with act-order
+   input_ = _shuffle_input_channels_if_needed(context, input_);
  auto res = woq_linear_kernel(
      input_,
      context.at_weight_,
@@ -299,6 +337,8 @@ at::Tensor run_eltwise(
      w_k,
      " respectively.");
  auto input_ = input.contiguous();
+   // handle GPTQ with act-order
+   input_ = _shuffle_input_channels_if_needed(context, input_);
  return woq_linear_eltwise_kernel(
      input_,
      context.at_weight_,
@@ -330,6 +370,8 @@ at::Tensor run_add(
      w_k,
      " respectively.");
  auto input_ = input.contiguous();
+   // handle GPTQ with act-order
+   input_ = _shuffle_input_channels_if_needed(context, input_);
  return woq_linear_add_kernel(
      input_,
      context.at_weight_,
@@ -359,6 +401,8 @@ at::Tensor run_add_add(
      w_k,
      " respectively.");
  auto input_ = input.contiguous();
+   // handle GPTQ with act-order
+   input_ = _shuffle_input_channels_if_needed(context, input_);
  return woq_linear_add_add_kernel(
      input_,
      context.at_weight_,
@@ -380,10 +424,19 @@ at::Tensor pack(ContextLinearWoq& context, const at::Tensor& tensor) {
at::Tensor unpack(ContextLinearWoq& context, const at::Tensor& tensor) {
  // By using different kernels, the packed weight dim can be 2 or 4
  // Return result directly if dim == 2
-   // For dim == 4, make a new quantized tensor and return.
+   // For dim == 4, the weight may be padded.
  // For padded weight (int4), make a slice of it.
  auto unpacked_weight =
      woq_linear_unpack_weight(tensor, context.is_int4_, context.lowp_mode_);
+   // With g_idx, the weight's input channels were shuffled along ic so that
+   // those in the same group are contiguous. Shuffle them back to the
+   // original order here.
+   if (context.group_size_ > 0 && context.g_idx_.has_value()) {
+     auto group_size = context.group_size_;
+     auto& g_idx = context.g_idx_.value();
+     unpacked_weight = woq_shuffle_weight_back_by_group_idx(
+         unpacked_weight, context.weight_shape_, g_idx, group_size);
+   }
  if (tensor.dim() > 2) {
    auto scales = context.scales_list_[0];
    auto zero_points = context.zero_points_list_[0];
@@ -404,6 +457,185 @@ at::Tensor unpack(ContextLinearWoq& context, const at::Tensor& tensor) {
  return unpacked_weight;
}

+ template <typename T, typename Tg, bool is_int4 = false>
+ at::Tensor woq_shuffle_tensor_by_group_idx_impl(
+     const at::Tensor& tensor,
+     const std::vector<int64_t>& tensor_shape,
+     const at::Tensor& g_idx,
+     int64_t group_size) {
+   // g_idx shape = [ic]
+   // The i-th element indicates which group tensor[:][i] belongs to.
+   // Shuffle tensor along ic to make channels in the same group contiguous.
+   int64_t N = tensor_shape[0];
+   int64_t K = tensor_shape[1];
+   auto shuffled_tensor = at::zeros_like(tensor, tensor.dtype());
+   auto shuffled_tensor_data = reinterpret_cast<T*>(shuffled_tensor.data_ptr());
+   auto tensor_data = reinterpret_cast<T*>(tensor.data_ptr());
+   auto num_groups = (K + group_size - 1) / group_size;
+   auto g_idx_data = reinterpret_cast<Tg*>(g_idx.data_ptr());
+ #pragma omp parallel for
+   for (int64_t i = 0; i < N; ++i) {
+     std::vector<int64_t> counts_per_group(num_groups, 0);
+     auto stride = is_int4 ? K / 2 : K;
+     auto tensor_row_data = tensor_data + i * stride;
+     auto shuffled_row_data = shuffled_tensor_data + i * stride;
+     for (int64_t j = 0; j < K; ++j) {
+       auto g = g_idx_data[j];
+       auto new_idx = g * group_size + counts_per_group[g];
+       constexpr bool T_is_int8 =
+           std::is_same<T, int8_t>() || std::is_same<T, uint8_t>();
+       if constexpr (is_int4 && T_is_int8) {
+         // Two int4 values share one byte: even channels occupy the low
+         // nibble, odd channels the high nibble.
+         uint8_t mask = j % 2 ? 0xF0 : 0x0F;
+         size_t rshift = j % 2 ? 4 : 0;
+         T data = (tensor_row_data[j / 2] & mask) >> rshift;
+         shuffled_row_data[new_idx / 2] =
+             shuffled_row_data[new_idx / 2] | (new_idx % 2 ? (data << 4) : data);
+       } else {
+         T data = tensor_row_data[j];
+         shuffled_row_data[new_idx] = data;
+       }
+       ++counts_per_group[g];
+     }
+   }
+   return shuffled_tensor;
+ }
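+
+ // Worked example of the int4 path: for K = 4, group_size = 2 and
+ // g_idx = [0, 1, 0, 1], channel j = 2 (group 0) gets new_idx = 1. Its
+ // value is read from the low nibble of byte j / 2 = 1 and written to the
+ // high nibble of byte new_idx / 2 = 0, since new_idx % 2 == 1.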
+
+ /**
+  * Shuffle an activation or weight tensor along the input channel according
+  * to group index (g_idx), so that input channels in the same group are
+  * contiguous to each other.
+  *
+  * @param is_int4 Whether the tensor stores int4 data
+  * @param tensor The tensor to be shuffled. It must be 2d.
+  * @param tensor_shape The original shape of the tensor. It differs from
+  * tensor.shape() when dtype is int4 since 2 int4 values are packed into one
+  * int8.
+  * @param g_idx The g_idx tensor holding the group index for each input
+  * channel. Its shape is [number of input channels]. Indices should be in
+  * [0, number of groups).
+  * @param group_size The group size of input channels. Used to determine the
+  * number of groups.
+  * @return The shuffled tensor.
+  */
+ // Default to false so activation tensors can call without a template
+ // argument.
+ template <bool is_int4 = false>
+ at::Tensor woq_shuffle_tensor_by_group_idx(
+     const at::Tensor& tensor,
+     const std::vector<int64_t>& tensor_shape,
+     const at::Tensor& g_idx,
+     int64_t group_size) {
+   at::Tensor out;
+   product_dispatcher<
+       std::tuple<at::ScalarType, at::ScalarType>,
+       std::tuple<
+           enumerate_dispatcher<
+               at::ScalarType,
+               at::kDouble,
+               at::kFloat,
+               at::kBFloat16,
+               at::kHalf,
+               at::kChar,
+               at::kByte>,
+           enumerate_dispatcher<at::ScalarType, at::kInt, at::kLong>>>::
+       call(
+           std::make_tuple(tensor.scalar_type(), g_idx.scalar_type()),
+           [&](auto dtype_tuple) {
+             auto tensor_dtype = std::get<0>(dtype_tuple);
+             auto g_idx_dtype = std::get<1>(dtype_tuple);
+             using t_cpp_type =
+                 typename c10::impl::ScalarTypeToCPPType<tensor_dtype>::type;
+             using g_cpp_type =
+                 typename c10::impl::ScalarTypeToCPPType<g_idx_dtype>::type;
+             out = woq_shuffle_tensor_by_group_idx_impl<
+                 t_cpp_type,
+                 g_cpp_type,
+                 is_int4>(tensor, tensor_shape, g_idx, group_size);
+           },
+           [](auto dtype_tuple) {
+             TORCH_CHECK(
+                 false, "Unsupported tensor data type for WOQ with g_idx");
+           });
+   return out;
+ }
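+
+ // Usage sketch (hypothetical tensors, for illustration only): shuffle an
+ // int4 qweight at pack time and a floating-point activation at run time.
+ //   auto w = woq_shuffle_tensor_by_group_idx</*is_int4*/ true>(
+ //       qweight, {N, K}, g_idx, group_size);
+ //   auto x = woq_shuffle_tensor_by_group_idx(
+ //       act, {M, K}, g_idx, group_size); // is_int4 defaults to false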
+
+ template <typename T, typename Tg>
+ at::Tensor woq_shuffle_weight_back_by_group_idx_impl(
+     const at::Tensor& qweight,
+     const std::vector<int64_t>& weight_shape,
+     const at::Tensor& g_idx,
+     int64_t group_size) {
+   auto N = weight_shape[0];
+   auto K = weight_shape[1];
+   auto shuffled_tensor = at::zeros_like(qweight, qweight.dtype());
+   auto shuffled_tensor_data = reinterpret_cast<T*>(shuffled_tensor.data_ptr());
+   auto tensor_data = reinterpret_cast<T*>(qweight.data_ptr());
+   auto num_groups = (K + group_size - 1) / group_size;
+   auto g_idx_data = reinterpret_cast<Tg*>(g_idx.data_ptr());
+ #pragma omp parallel for
+   for (int64_t i = 0; i < N; ++i) {
+     std::vector<int64_t> counts_per_group(num_groups, 0);
+     auto stride = K / 2;
+     auto tensor_row_data = tensor_data + i * stride;
+     auto shuffled_row_data = shuffled_tensor_data + i * stride;
+     for (int64_t j = 0; j < K; ++j) {
+       auto g = g_idx_data[j];
+       T* data_pos =
+           tensor_row_data + g * group_size / 2 + counts_per_group[g] / 2;
+       uint8_t mask = counts_per_group[g] % 2 ? 0xF0 : 0x0F;
+       size_t rshift = counts_per_group[g] % 2 ? 4 : 0;
+       T data = (*data_pos & mask) >> rshift;
+       shuffled_row_data[j / 2] =
+           shuffled_row_data[j / 2] | (j % 2 ? (data << 4) : data);
+       ++counts_per_group[g];
+     }
+   }
+   return shuffled_tensor;
+ }
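+
+ // This is the inverse of the pack-time shuffle: the counts_per_group[g]-th
+ // entry within group g of the packed row is written back to original
+ // channel j, restoring the pre-shuffle channel order.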
+
+ /**
+  * Shuffle a weight tensor along the input channel according to group index
+  * (g_idx) back to its original order. It is used for unpacking the weight.
+  * The data type is assumed to be INT4.
+  *
+  * @param qweight The weight to be shuffled. It must be 2d.
+  * @param weight_shape The original shape of the weight. It differs from
+  * tensor.shape() since 2 int4 values are packed into one int8.
+  * @param g_idx The g_idx tensor holding the group index for each input
+  * channel. Its shape is [number of input channels]. Indices should be in
+  * [0, number of groups).
+  * @param group_size The group size of input channels. Used to determine the
+  * number of groups.
+  * @return The shuffled tensor.
+  */
+ at::Tensor woq_shuffle_weight_back_by_group_idx(
+     const at::Tensor& qweight,
+     const std::vector<int64_t>& weight_shape,
+     const at::Tensor& g_idx,
+     int64_t group_size) {
+   at::Tensor out;
+   product_dispatcher<
+       std::tuple<at::ScalarType, at::ScalarType>,
+       std::tuple<
+           enumerate_dispatcher<at::ScalarType, at::kChar, at::kByte>,
+           enumerate_dispatcher<at::ScalarType, at::kInt, at::kLong>>>::
+       call(
+           std::make_tuple(qweight.scalar_type(), g_idx.scalar_type()),
+           [&](auto dtype_tuple) {
+             auto tensor_dtype = std::get<0>(dtype_tuple);
+             auto g_idx_dtype = std::get<1>(dtype_tuple);
+             using t_cpp_type =
+                 typename c10::impl::ScalarTypeToCPPType<tensor_dtype>::type;
+             using g_cpp_type =
+                 typename c10::impl::ScalarTypeToCPPType<g_idx_dtype>::type;
+             out = woq_shuffle_weight_back_by_group_idx_impl<
+                 t_cpp_type,
+                 g_cpp_type>(qweight, weight_shape, g_idx, group_size);
+           },
+           [](auto dtype_tuple) {
+             TORCH_CHECK(
+                 false, "Unsupported tensor data type for WOQ with g_idx");
+           });
+   return out;
+ }
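+
+ // Round-trip property relied on by unpack(): shuffling an int4 weight by
+ // g_idx and then shuffling it back recovers the original, i.e.
+ //   woq_shuffle_weight_back_by_group_idx(
+ //       woq_shuffle_tensor_by_group_idx</*is_int4*/ true>(
+ //           w, {N, K}, g_idx, group_size),
+ //       {N, K}, g_idx, group_size) == w.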
+
} // namespace woq_linear
} // namespace detail
} // namespace cpu