
Commit adb5638

WOQ: Remove concat-linear implementation from kernel (#2617)
* Remove num_concats for WOQ in the frontend to align with bf16 TPP linear
* Remove num_concats from the WOQ linear kernel
* Fix clang-format issue
1 parent d6130df commit adb5638
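Context: with num_concats > 1, the WOQ kernel fused several output-concatenated linear layers into one GEMM, writing each layer's output through a [num_concats, M, Nc/num_concats, Nb] view, and a fallback path reordered the flat [M, N] result afterwards. Both paths are gone after this commit, matching the bf16 TPP linear frontend. A minimal sketch of the fallback reordering, reconstructed from the deleted lines in WoqTppKrnl.cpp below (the standalone function wrapper is illustrative, not code from this repository):

#include <ATen/ATen.h>

// Hypothetical standalone form of the deleted fallback step. The
// view/transpose chain is copied verbatim from the removed code; it
// regrouped the [M, N] output so each concatenated linear's slice
// became contiguous before post-ops were applied.
at::Tensor reorder_concat_output(const at::Tensor& y, int64_t num_concats) {
  return y.view({-1, num_concats, y.size(-1) / num_concats})
      .transpose(0, 1)
      .contiguous()
      .view({-1, y.size(-1)});
}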

File tree

14 files changed: +28 −201 lines
csrc/cpu/aten/Linear.cpp

Lines changed: 0 additions & 8 deletions
@@ -429,7 +429,6 @@ at::Tensor woq_linear_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode) {
   int w_dtype = is_int4 ? WOQ_DTYPE_QINT4 : WOQ_DTYPE_QINT8;
   int64_t quant_w_mode = group_size > 0 ? 1 : 0;
@@ -442,7 +441,6 @@ at::Tensor woq_linear_kernel(
       bias_list,
       w_dtype,
       lowp_mode,
-      num_concats,
       WOQ_FUSE_NONE, // no post op fusion
       std::vector<at::Tensor>(),
       act_quant_mode,
@@ -472,7 +470,6 @@ at::Tensor woq_linear_eltwise_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode) {
   int w_dtype = is_int4 ? WOQ_DTYPE_QINT4 : WOQ_DTYPE_QINT8;
   int64_t post_op_fusion_type = WOQ_FUSE_NONE;
@@ -493,7 +490,6 @@ at::Tensor woq_linear_eltwise_kernel(
       bias_list,
       w_dtype,
       lowp_mode,
-      num_concats,
       post_op_fusion_type,
       std::vector<at::Tensor>(),
       act_quant_mode,
@@ -532,7 +528,6 @@ at::Tensor woq_linear_add_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     const std::vector<at::Tensor>& others,
     int64_t act_quant_mode) {
   int w_dtype = is_int4 ? WOQ_DTYPE_QINT4 : WOQ_DTYPE_QINT8;
@@ -546,7 +541,6 @@ at::Tensor woq_linear_add_kernel(
       bias_list,
       w_dtype,
       lowp_mode,
-      num_concats,
       WOQ_FUSE_ADD, // post op add
       others,
       act_quant_mode,
@@ -563,7 +557,6 @@ at::Tensor woq_linear_add_add_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     const std::vector<at::Tensor>& others,
     int64_t act_quant_mode) {
   int w_dtype = is_int4 ? WOQ_DTYPE_QINT4 : WOQ_DTYPE_QINT8;
@@ -577,7 +570,6 @@ at::Tensor woq_linear_add_add_kernel(
       bias_list,
       w_dtype,
       lowp_mode,
-      num_concats,
       WOQ_FUSE_ADD_ADD, // post op add-add
       others,
       act_quant_mode,
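After these hunks, the four frontend kernels simply take one argument fewer. A hedged sketch of an updated call site (the leading tensor arguments are placeholders and the values are illustrative defaults, not taken from this commit; only the trailing parameters appear in the hunks above):

#include <ATen/ATen.h>
#include <vector>

// Sketch only: the argument order before is_int4 is assumed.
at::Tensor call_woq_linear(
    const at::Tensor& x,
    const at::Tensor& qw_packed,
    const std::vector<at::Tensor>& scales_list,
    const std::vector<at::Tensor>& zps_list,
    const std::vector<at::Tensor>& bias_list) {
  return woq_linear_kernel(
      x, qw_packed, scales_list, zps_list, bias_list,
      /*is_int4=*/true,
      /*group_size=*/-1, // per-channel quantization
      /*lowp_mode=*/0,
      // num_concats is no longer passed
      /*act_quant_mode=*/0);
}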

csrc/cpu/aten/Linear.h

Lines changed: 0 additions & 5 deletions
@@ -105,7 +105,6 @@ at::Tensor woq_linear_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode);
 
 at::Tensor woq_linear_eltwise_kernel(
@@ -120,7 +119,6 @@ at::Tensor woq_linear_eltwise_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t act_quant_mode);
 
 at::Tensor woq_linear_add_kernel(
@@ -132,7 +130,6 @@ at::Tensor woq_linear_add_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     const std::vector<at::Tensor>& others,
     int64_t act_quant_mode);
 
@@ -145,7 +142,6 @@ at::Tensor woq_linear_add_add_kernel(
     bool is_int4,
     int64_t group_size,
     int64_t lowp_mode,
-    int64_t num_concats,
     const std::vector<at::Tensor>& others,
     int64_t act_quant_mode);
 
@@ -220,7 +216,6 @@ using woq_tpp_gemm_kernel_fn = at::Tensor (*)(
     const int,
     int64_t,
     int64_t,
-    int64_t,
     const std::vector<at::Tensor>&,
     int64_t,
     int64_t,

csrc/cpu/aten/kernels/WoqTppKrnl.cpp

Lines changed: 10 additions & 70 deletions
@@ -1658,7 +1658,6 @@ void qlinear_woq_affine_impl(
     at::Tensor y,
     const int qw_type,
     int k_splits,
-    int num_concats,
     int fusion_type,
     const TensorList& others_list,
     int64_t quant_block_k,
@@ -1681,9 +1680,6 @@ void qlinear_woq_affine_impl(
       quant_block_k == 0 ? 1 : (K + quant_block_k - 1) / quant_block_k;
 
   TLA_ASSERT(Nb % 16 == 0, "Nb must be a multiple of 16");
-  TLA_ASSERT(
-      num_concats <= 1 || Nc % num_concats == 0,
-      "Nc must be a multiple of num_concats");
 
   // select BLOCK_M according to M
   // TODO(jgong5): improve the heuristic
@@ -1700,7 +1696,7 @@ void qlinear_woq_affine_impl(
   auto BLOCK_M_rem = M % BLOCK_M;
 
   // TODO(jgong5): use heuristics to decide k_splits
-  if (k_splits <= 0 || num_concats > 1 || M >= 32 || BLOCK_M_rem) {
+  if (k_splits <= 0 || M >= 32 || BLOCK_M_rem) {
     k_splits = 1;
   }
   TLA_ASSERT(Kc % k_splits == 0, "Kc must be a multiple of k_splits");
@@ -1713,15 +1709,13 @@ void qlinear_woq_affine_impl(
       k_splits == 1;
 
   auto lda = no_x_buf ? K : Kb;
-  auto ldy = num_concats <= 1 ? N : Nc / num_concats * Nb;
+  auto ldy = N;
   auto ldc = (no_y_buf || k_splits > 1) ? ldy : Nb;
 
   auto px = GetVLAPtr<T>(x, {Kc, Kb});
   auto pw = GetVLAPtr<uint8_t>(
       (uint8_t*)qw_packed.data_ptr(), {Kc, Kb * (is_4bit_flag ? Nb / 2 : Nb)});
   auto py = GetVLAPtr<Tout>(y, {Nc, Nb}); /*[M, Nc, Nb]*/
-  auto py_concat = GetVLAPtr<Tout>(
-      y, {M, Nc / num_concats, Nb}); /*[num_concats, M, Nc/num_concats, Nb]*/
   int scales_kc = quant_w_mode == QUANT_W_PER_CHANNEL ? QUANT_W_PER_K_BLOCK
                                                       : quant_k_blocks;
   auto pscales = GetVLAPtr<TScale>(scales, {scales_kc, Nb});
@@ -1730,12 +1724,8 @@ void qlinear_woq_affine_impl(
   auto pb = GetVLAPtr<TGemmOut>(b, {Nb});
   auto tin0 = others_list.size() > 0 ? others_list[0] : at::Tensor{};
   auto pin0 = GetVLAPtr<Tout>(tin0, {Nc, Nb}); /*[M, Nc, Nb]*/
-  auto pin0_concat = GetVLAPtr<Tout>(
-      tin0, {M, Nc / num_concats, Nb}); /*[num_concats, M, Nc/num_concats, Nb]*/
   auto tin1 = others_list.size() > 1 ? others_list[1] : at::Tensor{};
   auto pin1 = GetVLAPtr<Tout>(tin1, {Nc, Nb}); /*[M, Nc, Nb]*/
-  auto pin1_concat = GetVLAPtr<Tout>(
-      tin1, {M, Nc / num_concats, Nb}); /*[num_concats, M, Nc/num_concats, Nb]*/
 
   auto copy_bias_out_tpp = CpyBiasTPP<TGemmOut>(BLOCK_M, Nb, ldy);
   auto copy_bias_buf_tpp = CpyBiasTPP<TGemmOut>(BLOCK_M, Nb, Nb);
@@ -1754,19 +1744,9 @@ void qlinear_woq_affine_impl(
   bool is_fusion_type_addrelated =
       fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD;
   auto post_ops_fn = [&](int m, int nc) {
-    Tout* y_ptr = num_concats <= 1
-        ? (Tout*)py[m][nc]
-        : (Tout*)py_concat[nc / (Nc / num_concats)][m][nc % (Nc / num_concats)];
-    Tout* tin0_ptr = is_fusion_type_addrelated ? num_concats <= 1
-            ? (Tout*)pin0[m][nc]
-            : (Tout*)pin0_concat[nc / (Nc / num_concats)][m]
-                                [nc % (Nc / num_concats)]
-                                               : nullptr;
-    Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? num_concats <= 1
-            ? (Tout*)pin1[m][nc]
-            : (Tout*)pin1_concat[nc / (Nc / num_concats)][m]
-                                [nc % (Nc / num_concats)]
-                                                 : nullptr;
+    Tout* y_ptr = (Tout*)py[m][nc];
+    Tout* tin0_ptr = is_fusion_type_addrelated ? (Tout*)pin0[m][nc] : nullptr;
+    Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? (Tout*)pin1[m][nc] : nullptr;
     if (fusion_type == FUSE_GELU_ERF) {
       gelu_erf_fwd_tpp(y_ptr, y_ptr);
     } else if (fusion_type == FUSE_ADD) {
@@ -1779,19 +1759,11 @@ void qlinear_woq_affine_impl(
     }
   };
   auto post_ops_rem_fn = [&](int m, int nc) {
-    Tout* y_ptr = num_concats <= 1
-        ? (Tout*)py[m][nc]
-        : (Tout*)py_concat[nc / (Nc / num_concats)][m][nc % (Nc / num_concats)];
+    Tout* y_ptr = (Tout*)py[m][nc];
     Tout* tin0_ptr = (fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD)
-        ? num_concats <= 1 ? (Tout*)pin0[m][nc]
-                           : (Tout*)pin0_concat[nc / (Nc / num_concats)][m]
-                                               [nc % (Nc / num_concats)]
+        ? (Tout*)pin0[m][nc]
         : nullptr;
-    Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? num_concats <= 1
-            ? (Tout*)pin1[m][nc]
-            : (Tout*)pin1_concat[nc / (Nc / num_concats)][m]
-                                [nc % (Nc / num_concats)]
-                                                 : nullptr;
+    Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? (Tout*)pin1[m][nc] : nullptr;
     if (fusion_type == FUSE_GELU_ERF) {
       gelu_erf_fwd_rem_tpp(y_ptr, y_ptr);
     } else if (fusion_type == FUSE_ADD) {
@@ -1961,10 +1933,7 @@ void qlinear_woq_affine_impl(
           }
         }
         bool is_rem = (m + BLOCK_M > M);
-        TGemmOut* y_ptr = num_concats <= 1
-            ? (TGemmOut*)py[m][nc]
-            : (TGemmOut*)py_concat[nc / (Nc / num_concats)][m]
-                                  [nc % (Nc / num_concats)];
+        TGemmOut* y_ptr = (TGemmOut*)py[m][nc];
         if (!is_rem) {
           if (kc == 0) {
             if (b.defined()) {
@@ -2073,10 +2042,7 @@ void qlinear_woq_affine_impl(
         int kc_end = kc_start + Kc / k_splits;
         int m = idx[2];
         bool is_rem = (m + BLOCK_M > M);
-        auto y_out_ptr = num_concats <= 1
-            ? py[m][nc]
-            : py_concat[nc / (Nc / num_concats)][m]
-                       [nc % (Nc / num_concats)];
+        auto y_out_ptr = py[m][nc];
         alignas(64) TGemmOut y_buf[BLOCK_M][Nb];
         TGemmOut* y_ptr = y_private_ptr[my_id][m][nc];
         if (k_splits > 1) {
@@ -3389,7 +3355,6 @@ at::Tensor qlinear_woq_affine(
     const TensorList& bias_list,
     const int qw_type,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t fusion_type,
     const TensorList& others_list,
     int64_t quant_a_mode = -1,
@@ -3449,7 +3414,6 @@ at::Tensor qlinear_woq_affine(
             y,
             qw_type,
             k_splits,
-            num_concats,
             fusion_type,
             others_list,
             quant_block_k);
@@ -3470,7 +3434,6 @@ at::Tensor qlinear_woq_affine(
             y,
             qw_type,
             k_splits,
-            num_concats,
             fusion_type,
             others_list,
             quant_block_k,
@@ -3494,7 +3457,6 @@ at::Tensor qlinear_woq_affine(
             y,
             qw_type,
             k_splits,
-            num_concats,
             fusion_type,
             others_list,
             quant_block_k);
@@ -3515,7 +3477,6 @@ at::Tensor qlinear_woq_affine(
             y,
             qw_type,
             k_splits,
-            num_concats,
             fusion_type,
             others_list,
             quant_block_k,
@@ -3544,7 +3505,6 @@ at::Tensor qlinear_woq_affine(
             y,
             qw_type,
             k_splits,
-            num_concats,
             fusion_type,
             others_list,
             quant_block_k);
@@ -3565,7 +3525,6 @@ at::Tensor qlinear_woq_affine(
             y,
             qw_type,
             k_splits,
-            num_concats,
             fusion_type,
             others_list,
             quant_block_k,
@@ -3589,7 +3548,6 @@ at::Tensor qlinear_woq_affine(
             y,
             qw_type,
             k_splits,
-            num_concats,
             fusion_type,
             others_list,
             quant_block_k);
@@ -3610,7 +3568,6 @@ at::Tensor qlinear_woq_affine(
             y,
             qw_type,
             k_splits,
-            num_concats,
             fusion_type,
             others_list,
             quant_block_k,
@@ -3639,7 +3596,6 @@ at::Tensor qlinear_woq_affine(
             y,
             qw_type,
             k_splits,
-            num_concats,
             fusion_type,
             others_list,
             quant_block_k);
@@ -3660,7 +3616,6 @@ at::Tensor qlinear_woq_affine(
             y,
             qw_type,
             k_splits,
-            num_concats,
             fusion_type,
             others_list,
             quant_block_k,
@@ -3697,7 +3652,6 @@ at::Tensor qlinear_woq_affine(
             y,
             qw_type,
             k_splits,
-            num_concats,
             fusion_type,
             others_list,
             quant_block_k,
@@ -3738,7 +3692,6 @@ at::Tensor qlinear_woq_affine(
             y,
             qw_type,
             k_splits,
-            num_concats,
             fusion_type,
             others_list,
             quant_block_k,
@@ -3858,12 +3811,6 @@ at::Tensor qlinear_woq_affine(
                                          : bf16_idx;
       y = at::add(y, biases[b_index]);
     }
-    if (num_concats > 1) {
-      y = y.view({-1, num_concats, y.size(-1) / num_concats})
-              .transpose(0, 1)
-              .contiguous()
-              .view({-1, y.size(-1)});
-    }
     if (fusion_type == FUSE_GELU_ERF) {
       y = at::gelu(y);
     } else if (fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD) {
@@ -3892,7 +3839,6 @@ at::Tensor qlinear_woq_affine(
     const TensorList& bias_list,
     const int qw_type,
     int64_t lowp_mode,
-    int64_t num_concats,
     int64_t fusion_type,
     const TensorList& others_list,
     int64_t quant_a_mode = -1,
@@ -4007,12 +3953,6 @@ at::Tensor qlinear_woq_affine(
                                          : bf16_idx;
       y = at::add(y, biases[b_index]);
     }
-    if (num_concats > 1) {
-      y = y.view({-1, num_concats, y.size(-1) / num_concats})
-              .transpose(0, 1)
-              .contiguous()
-              .view({-1, y.size(-1)});
-    }
     if (fusion_type == FUSE_GELU_ERF) {
       y = at::gelu(y);
     } else if (fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD) {
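The kernel-side effect of the removal is easiest to see in the output indexing. A minimal sketch (not kernel code) of the two addressing schemes, reconstructed from the deleted py_concat lines above; M, Nc, Nb and num_concats have the same meaning as in qlinear_woq_affine_impl:

#include <cstdint>

// Layout after this commit: y is [M, Nc, Nb], so block (m, nc) is py[m][nc].
int64_t offset_plain(int64_t m, int64_t nc, int64_t Nc, int64_t Nb) {
  return (m * Nc + nc) * Nb;
}

// Layout before this commit when num_concats > 1: y was viewed as
// [num_concats, M, Nc/num_concats, Nb] and block (m, nc) lived at
// py_concat[nc / (Nc / num_concats)][m][nc % (Nc / num_concats)].
int64_t offset_concat(
    int64_t m,
    int64_t nc,
    int64_t M,
    int64_t Nc,
    int64_t Nb,
    int64_t num_concats) {
  int64_t per_concat = Nc / num_concats;
  return (((nc / per_concat) * M + m) * per_concat + nc % per_concat) * Nb;
}

Dropping the second scheme is also what lets ldy collapse to N and removes num_concats from the k_splits heuristic.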

csrc/cpu/jit/cpu/kernels/ContextLinearWoq.h

Lines changed: 0 additions & 3 deletions
@@ -19,7 +19,6 @@ struct ContextLinearWoq final {
   bool is_int4_;
   int64_t group_size_;
   int64_t lowp_mode_;
-  int64_t num_concats_;
   int64_t act_quant_mode_;
 
   ContextLinearWoq() = delete;
@@ -34,7 +33,6 @@ struct ContextLinearWoq final {
       bool is_int4 = false,
       int64_t group_size = -1,
       int64_t lowp_mode = 0,
-      int64_t num_concats = 1,
       int64_t act_quant_mode = 0)
       : at_weight_(std::move(at_weight)),
         weight_shape_(std::move(weight_shape)),
@@ -43,7 +41,6 @@ struct ContextLinearWoq final {
         is_int4_(is_int4),
         group_size_(group_size),
         lowp_mode_(lowp_mode),
-        num_concats_(num_concats),
         act_quant_mode_(act_quant_mode) {
     // Make three dtype versions of scale, zp and bias
     // There is one more dtype for zp
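For reference, the scalar configuration the trimmed context now carries (a sketch assembled from the hunks above; the weight, scale and zero-point members are elided, not part of this diff):

struct ContextLinearWoq final {
  // ... tensor members (weight, scales, zero points, biases) elided ...
  bool is_int4_;
  int64_t group_size_;
  int64_t lowp_mode_;
  int64_t act_quant_mode_; // num_concats_ was removed by this commit
};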
