Commit be63628

WOQ: Support g_idx (#2550)
* WOQ: Support g_idx by shuffling weight ahead of time and shuffling activation at runtime
* Add UT for compute with g_idx
* Unpack int4 weight with g_idx
* Fix clang-format issue
* Use enumerate_dispatcher in woq_shuffle_tensor_by_group_idx
* Define a function for shuffling for unpack; support g_idx serialization; move GPTQ UT to nightly
* Separate WoqLinearOpContext::to_public and to_public_no_shuffle
* Disable concat linear if act_order is used
* Fix clang-format & flake8 issues
* Turn off act_order in run_gptq.py by default
* Fix UT failure with old ISA
* Fix WOQ unpack with g_idx UT failure with old ISA
* Fix DeepSpeed UT failures
* Add a helper function to check g_idx and shuffle input
1 parent b8a2bc7 commit be63628
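Background for the change: in GPTQ with act-order, g_idx assigns each input channel to a quantization group, so the channels of one group are generally scattered across the weight. This commit makes them contiguous by shuffling the weight once at pack time and the activation at run time. The snippet below is a minimal standalone sketch (not code from this commit) of the destination-slot rule the kernels use: channel j of group g goes to slot g * group_size + (running count within g).

#include <cstdint>
#include <iostream>
#include <vector>

// Compute where each input channel lands after the g_idx shuffle.
std::vector<int64_t> shuffled_slots(
    const std::vector<int64_t>& g_idx, int64_t group_size) {
  int64_t K = g_idx.size();
  int64_t num_groups = (K + group_size - 1) / group_size;
  std::vector<int64_t> counts(num_groups, 0);
  std::vector<int64_t> slot(K);
  for (int64_t j = 0; j < K; ++j) {
    int64_t g = g_idx[j];
    slot[j] = g * group_size + counts[g]++; // stable within each group
  }
  return slot;
}

int main() {
  // 6 channels, group_size = 2, 3 groups.
  for (auto s : shuffled_slots({2, 0, 1, 0, 2, 1}, 2))
    std::cout << s << ' '; // prints: 4 0 2 1 5 3
}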

File tree

12 files changed: +614 −145 lines changed

csrc/cpu/jit/cpu/kernels/ContextLinearWoq.h

Lines changed: 3 additions & 0 deletions
@@ -9,6 +9,7 @@ struct ContextLinearWoq final {
   at::Tensor at_weight_;
   std::vector<int64_t> weight_shape_;
   c10::optional<at::Tensor> at_bias_;
+  c10::optional<at::Tensor> g_idx_;
   // The list contains three dtype versions of bias, scale and zp
   // i.e., fp32, fp16, bf16
   // If bias is not present, it contains empty tensors
@@ -29,6 +30,7 @@ struct ContextLinearWoq final {
       at::Tensor&& scales_float,
       at::Tensor&& zero_point_float,
       c10::optional<at::Tensor>&& bias,
+      c10::optional<at::Tensor>&& g_idx,
       bool is_int4 = false,
       int64_t group_size = -1,
       int64_t lowp_mode = 0,
@@ -37,6 +39,7 @@ struct ContextLinearWoq final {
       : at_weight_(std::move(at_weight)),
         weight_shape_(std::move(weight_shape)),
         at_bias_(std::move(bias)),
+        g_idx_(std::move(g_idx)),
         is_int4_(is_int4),
         group_size_(group_size),
         lowp_mode_(lowp_mode),

csrc/cpu/jit/cpu/kernels/LinearWoqPacked.cpp

Lines changed: 237 additions & 5 deletions
@@ -3,6 +3,7 @@
 #include <ideep.hpp>
 #include "aten/Linear.h"
 #include "aten/WeightPack.h"
+#include "csrc/cpu/tpp/woq/tla.h"
 #include "ideep/IDeepConversions.h"
 
 namespace torch_ipex {
@@ -16,6 +17,7 @@ c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContext(
     at::Tensor&& scales,
     at::Tensor&& zero_points,
     c10::optional<at::Tensor>&& bias,
+    c10::optional<at::Tensor>&& g_idx,
     c10::optional<int64_t> batch_size,
     bool is_int4,
     int64_t group_size,
@@ -32,6 +34,7 @@ c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContext(
       std::move(scales),
       std::move(zero_points),
       std::move(bias),
+      std::move(g_idx),
       batch_size,
       is_int4,
       group_size,
@@ -45,6 +48,7 @@ c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContextInt4(
     at::Tensor&& scales,
     at::Tensor&& zero_points,
     c10::optional<at::Tensor>&& bias,
+    c10::optional<at::Tensor>&& g_idx,
     c10::optional<int64_t> batch_size,
     int64_t group_size, // group_size along input channel
     int64_t lowp_mode,
@@ -160,6 +164,7 @@ c10::intrusive_ptr<WoqLinearOpContext> createWoqLinearPrePackOpContextInt4(
       std::move(scales_fp32),
       std::move(zp_fp32),
       std::move(bias),
+      std::move(g_idx),
       batch_size,
       /*is_int4*/ true,
       group_size,
@@ -183,17 +188,29 @@ ContextLinearWoq create(
     at::Tensor& scales,
     at::Tensor& zero_points,
     const c10::optional<at::Tensor>& bias,
+    const c10::optional<at::Tensor>& g_idx,
     const c10::optional<int64_t> batch_size,
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
     int64_t num_concats,
     int64_t act_quant_mode) {
-  auto packed_weight = woq_linear_pack_weight(
-      weight, weight_shape, is_int4, group_size, lowp_mode);
+  at::Tensor packed_weight;
+  int64_t N = weight_shape[0];
+  int64_t K = weight_shape[1];
+  // GPTQ with act-order
+  // Shuffle weight along ic to make channels contiguous in group
+  if (is_int4 && group_size > 0 && g_idx.has_value()) {
+    // Shuffle weight along ic to make channels contiguous in group
+    auto shuffled_weight = woq_shuffle_tensor_by_group_idx</* is_int4 */ true>(
+        weight, weight_shape, g_idx.value(), group_size);
+    packed_weight = woq_linear_pack_weight(
+        shuffled_weight, weight_shape, is_int4, group_size, lowp_mode);
+  } else {
+    packed_weight = woq_linear_pack_weight(
+        weight, weight_shape, is_int4, group_size, lowp_mode);
+  }
   auto packed_shape = packed_weight.sizes();
-  int64_t N = weight.size(0);
-  int64_t K = weight.size(1);
   // If OC is not a multiple of BLOCK_N, it may be padded.
   bool oc_is_padded = (packed_shape.size() == 4 && is_int4 &&
       packed_shape[0] * packed_shape[3] * 2 != N) ||
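The pack-time shuffle is safe because a linear layer is invariant under any shared permutation of the contracted (input-channel) dimension of activation and weight; that is why shuffling the weight once here, paired with the run-time activation shuffle below, preserves results. A minimal ATen sketch of that invariant (illustrative, not part of the commit):

#include <ATen/ATen.h>

// x @ W^T is unchanged when x's columns and W's columns are permuted the
// same way; `perm` stands in for the order induced by g_idx.
bool permutation_preserves_linear() {
  auto x = at::randn({4, 8});
  auto w = at::randn({16, 8});
  auto perm = at::randperm(8);
  auto y_ref = at::linear(x, w);
  auto y_shuf = at::linear(x.index_select(1, perm), w.index_select(1, perm));
  return at::allclose(y_ref, y_shuf, /*rtol=*/1e-5, /*atol=*/1e-6);
}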
@@ -221,6 +238,7 @@
       std::move(scales_padded),
       std::move(zero_points_padded),
       c10::make_optional(bias_padded),
+      g_idx.has_value() ? c10::make_optional(*g_idx) : c10::nullopt,
       is_int4,
       group_size,
       lowp_mode,
@@ -233,6 +251,7 @@
       std::move(scales_padded),
       std::move(zero_points_padded),
       c10::nullopt,
+      g_idx.has_value() ? c10::make_optional(*g_idx) : c10::nullopt,
       is_int4,
       group_size,
       lowp_mode,
@@ -246,13 +265,30 @@
       std::move(scales),
       std::move(zero_points_float),
       bias.has_value() ? c10::make_optional(*bias) : c10::nullopt,
+      g_idx.has_value() ? c10::make_optional(*g_idx) : c10::nullopt,
       is_int4,
       group_size,
       lowp_mode,
       num_concats,
       act_quant_mode);
 }
 
+static at::Tensor _shuffle_input_channels_if_needed(
+    ContextLinearWoq& context,
+    const at::Tensor& input) {
+  // GPTQ with act-order
+  // Shuffle input channels to align with weight
+  if (context.is_int4_ && context.group_size_ > 0 &&
+      context.g_idx_.has_value()) {
+    auto& g_idx = context.g_idx_.value();
+    auto K = input.size(-1);
+    std::vector<int64_t> input_shape = {input.numel() / K, K};
+    return woq_shuffle_tensor_by_group_idx(
+        input, input_shape, g_idx, context.group_size_);
+  }
+  return input;
+}
+
 at::Tensor run(ContextLinearWoq& context, const at::Tensor& input) {
   // TPP kernel packs weight to 4d (Nc, Kc, block_k, block_n)
   auto w_k = context.weight_shape_[1];
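Note that the helper views an arbitrary N-D activation as a 2-D [rows, K] matrix before shuffling, because g_idx only permutes the last (input-channel) dimension; all leading dimensions are just rows. A hypothetical standalone equivalent of that shape computation:

#include <ATen/ATen.h>
#include <vector>

// Mirrors the {input.numel() / K, K} view used by the helper:
// e.g. a [batch, seq, K] activation becomes [batch * seq, K].
std::vector<int64_t> view_as_2d_shape(const at::Tensor& input) {
  auto K = input.size(-1); // input channels are the last dim
  return {input.numel() / K, K};
}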
@@ -264,6 +300,8 @@ at::Tensor run(ContextLinearWoq& context, const at::Tensor& input) {
       w_k,
       " respectively.");
   auto input_ = input.contiguous();
+  // handle GPTQ with act-order
+  input_ = _shuffle_input_channels_if_needed(context, input_);
   auto res = woq_linear_kernel(
       input_,
       context.at_weight_,
@@ -299,6 +337,8 @@ at::Tensor run_eltwise(
       w_k,
       " respectively.");
   auto input_ = input.contiguous();
+  // handle GPTQ with act-order
+  input_ = _shuffle_input_channels_if_needed(context, input_);
   return woq_linear_eltwise_kernel(
       input_,
       context.at_weight_,
@@ -330,6 +370,8 @@ at::Tensor run_add(
       w_k,
       " respectively.");
   auto input_ = input.contiguous();
+  // handle GPTQ with act-order
+  input_ = _shuffle_input_channels_if_needed(context, input_);
   return woq_linear_add_kernel(
       input_,
       context.at_weight_,
@@ -359,6 +401,8 @@ at::Tensor run_add_add(
       w_k,
       " respectively.");
   auto input_ = input.contiguous();
+  // handle GPTQ with act-order
+  input_ = _shuffle_input_channels_if_needed(context, input_);
   return woq_linear_add_add_kernel(
       input_,
       context.at_weight_,
@@ -380,10 +424,19 @@ at::Tensor pack(ContextLinearWoq& context, const at::Tensor& tensor) {
 at::Tensor unpack(ContextLinearWoq& context, const at::Tensor& tensor) {
   // By using different kernels, the packed weight dim can be 2 or 4
   // Return result directly if dim == 2
-  // For dim == 4, make a new quantized tensor and return.
+  // For dim == 4, weight may be padded.
   // For padded weight (int4), make a slice of it.
   auto unpacked_weight =
       woq_linear_unpack_weight(tensor, context.is_int4_, context.lowp_mode_);
+  // With g_idx, weight's input channels are shuffled along ic so that
+  // those in the same group are contiguous.
+  // Here we need to shuffle them to the original order.
+  if (context.group_size_ > 0 && context.g_idx_.has_value()) {
+    auto group_size = context.group_size_;
+    auto& g_idx = context.g_idx_.value();
+    unpacked_weight = woq_shuffle_weight_back_by_group_idx(
+        unpacked_weight, context.weight_shape_, g_idx, group_size);
+  }
   if (tensor.dim() > 2) {
     auto scales = context.scales_list_[0];
     auto zero_points = context.zero_points_list_[0];
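unpack must undo the pack-time shuffle so serialized or to_public weights come back in their original channel order. Since the shuffle is a pure permutation, mapping forward and then back is the identity; a small sketch of that round trip reusing the counting rule from the kernels (illustrative only):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Forward: channel j -> slot g * group_size + (running count within g).
  std::vector<int64_t> g_idx = {1, 0, 1, 0};
  const int64_t group_size = 2, K = g_idx.size();
  std::vector<int64_t> counts(2, 0), fwd(K), inv(K);
  for (int64_t j = 0; j < K; ++j)
    fwd[j] = g_idx[j] * group_size + counts[g_idx[j]]++;
  // Inverse: slot -> original channel, which is what shuffle-back computes.
  for (int64_t j = 0; j < K; ++j)
    inv[fwd[j]] = j;
  for (int64_t j = 0; j < K; ++j)
    assert(inv[fwd[j]] == j); // round trip restores the original order
}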
@@ -404,6 +457,185 @@ at::Tensor unpack(ContextLinearWoq& context, const at::Tensor& tensor) {
   return unpacked_weight;
 }
 
+template <typename T, typename Tg, bool is_int4 = false>
+at::Tensor woq_shuffle_tensor_by_group_idx_impl(
+    const at::Tensor& tensor,
+    const std::vector<int64_t>& tensor_shape,
+    const at::Tensor& g_idx,
+    int64_t group_size) {
+  // g_idx shape = [ic]
+  // i-th element indicates which group tensor[:][i] belongs to.
+  // Shuffle tensor along ic to make channels contiguous in group.
+  int64_t N = tensor_shape[0];
+  int64_t K = tensor_shape[1];
+  auto shuffled_tensor = at::zeros_like(tensor, tensor.dtype());
+  auto shuffled_tensor_data = reinterpret_cast<T*>(shuffled_tensor.data_ptr());
+  auto tensor_data = reinterpret_cast<T*>(tensor.data_ptr());
+  auto num_groups = (K + group_size - 1) / group_size;
+  auto g_idx_data = reinterpret_cast<Tg*>(g_idx.data_ptr());
+#pragma omp parallel for
+  for (int64_t i = 0; i < N; ++i) {
+    std::vector<int64_t> counts_per_group(num_groups, 0);
+    auto stride = is_int4 ? K / 2 : K;
+    auto tensor_row_data = tensor_data + i * stride;
+    auto shuffled_row_data = shuffled_tensor_data + i * stride;
+    for (int64_t j = 0; j < K; ++j) {
+      auto g = g_idx_data[j];
+      auto new_idx = g * group_size + counts_per_group[g];
+      constexpr bool T_is_int8 =
+          std::is_same<T, int8_t>() || std::is_same<T, uint8_t>();
+      if constexpr (is_int4 && T_is_int8) {
+        uint8_t mask = j % 2 ? 0xF0 : 0x0F;
+        size_t rshift = j % 2 ? 4 : 0;
+        T data = (tensor_row_data[j / 2] & mask) >> rshift;
+        shuffled_row_data[new_idx / 2] =
+            shuffled_row_data[new_idx / 2] | (new_idx % 2 ? (data << 4) : data);
+      } else {
+        T data = tensor_row_data[j];
+        shuffled_row_data[new_idx] = data;
+      }
+      ++counts_per_group[g];
+    }
+  }
+  return shuffled_tensor;
+}
+
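The int4 branch above handles the two-channels-per-byte layout: an even channel index lives in the low nibble and an odd index in the high nibble, which is what the mask/rshift pairs encode. A worked byte-level example under that assumption:

#include <cassert>
#include <cstdint>

int main() {
  uint8_t byte = 0xA3; // packs channel j=0 -> 0x3 (low), j=1 -> 0xA (high)
  uint8_t lo = byte & 0x0F;        // j % 2 == 0: mask 0x0F, rshift 0
  uint8_t hi = (byte & 0xF0) >> 4; // j % 2 == 1: mask 0xF0, rshift 4
  assert(lo == 0x3 && hi == 0xA);
  // Writing to a destination byte mirrors the shuffled_row_data update:
  uint8_t out = 0;
  out |= lo;      // even new_idx -> low nibble
  out |= hi << 4; // odd new_idx -> high nibble
  assert(out == byte);
}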
+/**
+ * Shuffle activation or weight tensor along input channel according to group
+ * index (g_idx), so that input channels in the same group are contiguous to
+ * each other.
+ *
+ * @param is_int4 The tensor stores int4 data or not
+ * @param tensor The tensor to be shuffled. It must be 2d.
+ * @param tensor_shape The original shape of the tensor. It is different from
+ * tensor.shape() when dtype is int4 since 2 int4 data are packed as one int8.
+ * @param g_idx The g_idx tensor contains group index for each input channel.
+ * Its shape is [number of input channels]. Indices should be in [0, number of
+ * groups).
+ * @param group_size The group size of input channels. Used to determine number
+ * of groups.
+ * @return The shuffled tensor.
+ */
+template <bool is_int4>
+at::Tensor woq_shuffle_tensor_by_group_idx(
+    const at::Tensor& tensor,
+    const std::vector<int64_t>& tensor_shape,
+    const at::Tensor& g_idx,
+    int64_t group_size) {
+  at::Tensor out;
+  product_dispatcher<
+      std::tuple<at::ScalarType, at::ScalarType>,
+      std::tuple<
+          enumerate_dispatcher<
+              at::ScalarType,
+              at::kDouble,
+              at::kFloat,
+              at::kBFloat16,
+              at::kHalf,
+              at::kChar,
+              at::kByte>,
+          enumerate_dispatcher<at::ScalarType, at::kInt, at::kLong>>>::
+      call(
+          std::make_tuple(tensor.scalar_type(), g_idx.scalar_type()),
+          [&](auto dtype_tuple) {
+            auto tensor_dtype = std::get<0>(dtype_tuple);
+            auto g_idx_dtype = std::get<1>(dtype_tuple);
+            using t_cpp_type =
+                typename c10::impl::ScalarTypeToCPPType<tensor_dtype>::type;
+            using g_cpp_type =
+                typename c10::impl::ScalarTypeToCPPType<g_idx_dtype>::type;
+            out = woq_shuffle_tensor_by_group_idx_impl<
+                t_cpp_type,
+                g_cpp_type,
+                is_int4>(tensor, tensor_shape, g_idx, group_size);
+          },
+          [](auto dtype_tuple) {
+            TORCH_CHECK(
+                false, "Unsupported tensor data type for WOQ with g_idx");
+          });
+  return out;
+}
+
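A hypothetical call site for the dispatcher above (shapes and values invented for illustration): shuffling a 2x4 fp32 activation with group_size 2 and g_idx = {1, 0, 1, 0} reorders each row as {c1, c3, c0, c2}, i.e. group-0 channels first, then group-1.

// Assumes the declarations above are in scope; inputs are invented.
at::Tensor shuffle_example() {
  at::Tensor act = at::randn({2, 4});
  std::vector<int32_t> idx = {1, 0, 1, 0};
  at::Tensor g_idx = at::tensor(idx, at::kInt);
  return woq_shuffle_tensor_by_group_idx</*is_int4=*/false>(
      act, /*tensor_shape=*/{2, 4}, g_idx, /*group_size=*/2);
}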
+template <typename T, typename Tg>
+at::Tensor woq_shuffle_weight_back_by_group_idx_impl(
+    const at::Tensor& qweight,
+    const std::vector<int64_t>& weight_shape,
+    const at::Tensor& g_idx,
+    int64_t group_size) {
+  auto N = weight_shape[0];
+  auto K = weight_shape[1];
+  auto shuffled_tensor = at::zeros_like(qweight, qweight.dtype());
+  auto shuffled_tensor_data = reinterpret_cast<T*>(shuffled_tensor.data_ptr());
+  auto tensor_data = reinterpret_cast<T*>(qweight.data_ptr());
+  auto num_groups = (K + group_size - 1) / group_size;
+  auto g_idx_data = reinterpret_cast<Tg*>(g_idx.data_ptr());
+#pragma omp parallel for
+  for (int64_t i = 0; i < N; ++i) {
+    std::vector<int64_t> counts_per_group(num_groups, 0);
+    auto stride = K / 2;
+    auto tensor_row_data = tensor_data + i * stride;
+    auto shuffled_row_data = shuffled_tensor_data + i * stride;
+    for (int64_t j = 0; j < K; ++j) {
+      auto g = g_idx_data[j];
+      T* data_pos =
+          tensor_row_data + g * group_size / 2 + counts_per_group[g] / 2;
+      uint8_t mask = counts_per_group[g] % 2 ? 0xF0 : 0x0F;
+      size_t rshift = counts_per_group[g] % 2 ? 4 : 0;
+      T data = (*data_pos & mask) >> rshift;
+      shuffled_row_data[j / 2] =
+          shuffled_row_data[j / 2] | (j % 2 ? (data << 4) : data);
+      ++counts_per_group[g];
+    }
+  }
+  return shuffled_tensor;
+}
+
/**
594+
* Shuffle weight tensor along input channel according to group index (g_idx)
595+
* to its original order. It is used for unpacking weight. Data type is assumed
596+
* INT4.
597+
*
598+
* @param qweight The weight to be shuffled. It must be 2d.
599+
* @param weight_shape The original shape of the weight. It is different from
600+
* tensor.shape() since 2 int4 data are packed as one int8.
601+
* @param g_idx The g_idx tensor contains group index for each input channel.
602+
* Its shape is [number of input channels]. Indices should be in [0, number of
603+
* groups).
604+
* @param group_size The group size of input channels. Used to determine number
605+
* of groups.
606+
* @return The shuffled tensor.
607+
*/
608+
at::Tensor woq_shuffle_weight_back_by_group_idx(
609+
const at::Tensor& qweight,
610+
const std::vector<int64_t>& weight_shape,
611+
const at::Tensor& g_idx,
612+
int64_t group_size) {
613+
at::Tensor out;
614+
product_dispatcher<
615+
std::tuple<at::ScalarType, at::ScalarType>,
616+
std::tuple<
617+
enumerate_dispatcher<at::ScalarType, at::kChar, at::kByte>,
618+
enumerate_dispatcher<at::ScalarType, at::kInt, at::kLong>>>::
619+
call(
620+
std::make_tuple(qweight.scalar_type(), g_idx.scalar_type()),
621+
[&](auto dtype_tuple) {
622+
auto tensor_dtype = std::get<0>(dtype_tuple);
623+
auto g_idx_dtype = std::get<1>(dtype_tuple);
624+
using t_cpp_type =
625+
typename c10::impl::ScalarTypeToCPPType<tensor_dtype>::type;
626+
using g_cpp_type =
627+
typename c10::impl::ScalarTypeToCPPType<g_idx_dtype>::type;
628+
out = woq_shuffle_weight_back_by_group_idx_impl<
629+
t_cpp_type,
630+
g_cpp_type>(qweight, weight_shape, g_idx, group_size);
631+
},
632+
[](auto dtype_tuple) {
633+
TORCH_CHECK(
634+
false, "Unsupported tensor data type for WOQ with g_idx");
635+
});
636+
return out;
637+
}
638+
407639
} // namespace woq_linear
408640
} // namespace detail
409641
} // namespace cpu
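Taken together: create() shuffles the weight once, _shuffle_input_channels_if_needed() shuffles each activation, and unpack() shuffles the weight back. A hypothetical round-trip check for the weight path (function names from this file; inputs invented, with qweight a 2-D uint8 tensor holding two int4 values per byte and logical shape weight_shape = {N, K}):

// Illustrative only: shuffle then shuffle-back should reproduce qweight.
at::Tensor weight_round_trip(
    const at::Tensor& qweight,
    const std::vector<int64_t>& weight_shape,
    const at::Tensor& g_idx,
    int64_t group_size) {
  auto shuffled = woq_shuffle_tensor_by_group_idx</*is_int4=*/true>(
      qweight, weight_shape, g_idx, group_size);
  return woq_shuffle_weight_back_by_group_idx(
      shuffled, weight_shape, g_idx, group_size);
}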
