@@ -1658,7 +1658,6 @@ void qlinear_woq_affine_impl(
1658
1658
at::Tensor y,
1659
1659
const int qw_type,
1660
1660
int k_splits,
1661
- int num_concats,
1662
1661
int fusion_type,
1663
1662
const TensorList& others_list,
1664
1663
int64_t quant_block_k,
@@ -1681,9 +1680,6 @@ void qlinear_woq_affine_impl(
1681
1680
quant_block_k == 0 ? 1 : (K + quant_block_k - 1 ) / quant_block_k;
1682
1681
1683
1682
TLA_ASSERT (Nb % 16 == 0 , " Nb must be a multiple of 16" );
1684
- TLA_ASSERT (
1685
- num_concats <= 1 || Nc % num_concats == 0 ,
1686
- " Nc must be a multiple of num_concats" );
1687
1683
1688
1684
// select BLOCK_M according to M
1689
1685
// TODO(jgong5): improve the heuristic
@@ -1700,7 +1696,7 @@ void qlinear_woq_affine_impl(
1700
1696
auto BLOCK_M_rem = M % BLOCK_M;
1701
1697
1702
1698
// TODO(jgong5): use heuristics to decide k_splits
1703
- if (k_splits <= 0 || num_concats > 1 || M >= 32 || BLOCK_M_rem) {
1699
+ if (k_splits <= 0 || M >= 32 || BLOCK_M_rem) {
1704
1700
k_splits = 1 ;
1705
1701
}
1706
1702
TLA_ASSERT (Kc % k_splits == 0 , " Kc must be a multiple of k_splits" );
@@ -1713,15 +1709,13 @@ void qlinear_woq_affine_impl(
1713
1709
k_splits == 1 ;
1714
1710
1715
1711
auto lda = no_x_buf ? K : Kb;
1716
- auto ldy = num_concats <= 1 ? N : Nc / num_concats * Nb ;
1712
+ auto ldy = N ;
1717
1713
auto ldc = (no_y_buf || k_splits > 1 ) ? ldy : Nb;
1718
1714
1719
1715
auto px = GetVLAPtr<T>(x, {Kc, Kb});
1720
1716
auto pw = GetVLAPtr<uint8_t >(
1721
1717
(uint8_t *)qw_packed.data_ptr (), {Kc, Kb * (is_4bit_flag ? Nb / 2 : Nb)});
1722
1718
auto py = GetVLAPtr<Tout>(y, {Nc, Nb}); /* [M, Nc, Nb]*/
1723
- auto py_concat = GetVLAPtr<Tout>(
1724
- y, {M, Nc / num_concats, Nb}); /* [num_concats, M, Nc/num_concats, Nb]*/
1725
1719
int scales_kc = quant_w_mode == QUANT_W_PER_CHANNEL ? QUANT_W_PER_K_BLOCK
1726
1720
: quant_k_blocks;
1727
1721
auto pscales = GetVLAPtr<TScale>(scales, {scales_kc, Nb});
@@ -1730,12 +1724,8 @@ void qlinear_woq_affine_impl(
1730
1724
auto pb = GetVLAPtr<TGemmOut>(b, {Nb});
1731
1725
auto tin0 = others_list.size () > 0 ? others_list[0 ] : at::Tensor{};
1732
1726
auto pin0 = GetVLAPtr<Tout>(tin0, {Nc, Nb}); /* [M, Nc, Nb]*/
1733
- auto pin0_concat = GetVLAPtr<Tout>(
1734
- tin0, {M, Nc / num_concats, Nb}); /* [num_concats, M, Nc/num_concats, Nb]*/
1735
1727
auto tin1 = others_list.size () > 1 ? others_list[1 ] : at::Tensor{};
1736
1728
auto pin1 = GetVLAPtr<Tout>(tin1, {Nc, Nb}); /* [M, Nc, Nb]*/
1737
- auto pin1_concat = GetVLAPtr<Tout>(
1738
- tin1, {M, Nc / num_concats, Nb}); /* [num_concats, M, Nc/num_concats, Nb]*/
1739
1729
1740
1730
auto copy_bias_out_tpp = CpyBiasTPP<TGemmOut>(BLOCK_M, Nb, ldy);
1741
1731
auto copy_bias_buf_tpp = CpyBiasTPP<TGemmOut>(BLOCK_M, Nb, Nb);
@@ -1754,19 +1744,9 @@ void qlinear_woq_affine_impl(
1754
1744
bool is_fusion_type_addrelated =
1755
1745
fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD;
1756
1746
auto post_ops_fn = [&](int m, int nc) {
1757
- Tout* y_ptr = num_concats <= 1
1758
- ? (Tout*)py[m][nc]
1759
- : (Tout*)py_concat[nc / (Nc / num_concats)][m][nc % (Nc / num_concats)];
1760
- Tout* tin0_ptr = is_fusion_type_addrelated ? num_concats <= 1
1761
- ? (Tout*)pin0[m][nc]
1762
- : (Tout*)pin0_concat[nc / (Nc / num_concats)][m]
1763
- [nc % (Nc / num_concats)]
1764
- : nullptr ;
1765
- Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? num_concats <= 1
1766
- ? (Tout*)pin1[m][nc]
1767
- : (Tout*)pin1_concat[nc / (Nc / num_concats)][m]
1768
- [nc % (Nc / num_concats)]
1769
- : nullptr ;
1747
+ Tout* y_ptr = (Tout*)py[m][nc];
1748
+ Tout* tin0_ptr = is_fusion_type_addrelated ? (Tout*)pin0[m][nc] : nullptr ;
1749
+ Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? (Tout*)pin1[m][nc] : nullptr ;
1770
1750
if (fusion_type == FUSE_GELU_ERF) {
1771
1751
gelu_erf_fwd_tpp (y_ptr, y_ptr);
1772
1752
} else if (fusion_type == FUSE_ADD) {
@@ -1779,19 +1759,11 @@ void qlinear_woq_affine_impl(
1779
1759
}
1780
1760
};
1781
1761
auto post_ops_rem_fn = [&](int m, int nc) {
1782
- Tout* y_ptr = num_concats <= 1
1783
- ? (Tout*)py[m][nc]
1784
- : (Tout*)py_concat[nc / (Nc / num_concats)][m][nc % (Nc / num_concats)];
1762
+ Tout* y_ptr = (Tout*)py[m][nc];
1785
1763
Tout* tin0_ptr = (fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD)
1786
- ? num_concats <= 1 ? (Tout*)pin0[m][nc]
1787
- : (Tout*)pin0_concat[nc / (Nc / num_concats)][m]
1788
- [nc % (Nc / num_concats)]
1764
+ ? (Tout*)pin0[m][nc]
1789
1765
: nullptr ;
1790
- Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? num_concats <= 1
1791
- ? (Tout*)pin1[m][nc]
1792
- : (Tout*)pin1_concat[nc / (Nc / num_concats)][m]
1793
- [nc % (Nc / num_concats)]
1794
- : nullptr ;
1766
+ Tout* tin1_ptr = fusion_type == FUSE_ADD_ADD ? (Tout*)pin1[m][nc] : nullptr ;
1795
1767
if (fusion_type == FUSE_GELU_ERF) {
1796
1768
gelu_erf_fwd_rem_tpp (y_ptr, y_ptr);
1797
1769
} else if (fusion_type == FUSE_ADD) {
@@ -1961,10 +1933,7 @@ void qlinear_woq_affine_impl(
1961
1933
}
1962
1934
}
1963
1935
bool is_rem = (m + BLOCK_M > M);
1964
- TGemmOut* y_ptr = num_concats <= 1
1965
- ? (TGemmOut*)py[m][nc]
1966
- : (TGemmOut*)py_concat[nc / (Nc / num_concats)][m]
1967
- [nc % (Nc / num_concats)];
1936
+ TGemmOut* y_ptr = (TGemmOut*)py[m][nc];
1968
1937
if (!is_rem) {
1969
1938
if (kc == 0 ) {
1970
1939
if (b.defined ()) {
@@ -2073,10 +2042,7 @@ void qlinear_woq_affine_impl(
2073
2042
int kc_end = kc_start + Kc / k_splits;
2074
2043
int m = idx[2 ];
2075
2044
bool is_rem = (m + BLOCK_M > M);
2076
- auto y_out_ptr = num_concats <= 1
2077
- ? py[m][nc]
2078
- : py_concat[nc / (Nc / num_concats)][m]
2079
- [nc % (Nc / num_concats)];
2045
+ auto y_out_ptr = py[m][nc];
2080
2046
alignas (64 ) TGemmOut y_buf[BLOCK_M][Nb];
2081
2047
TGemmOut* y_ptr = y_private_ptr[my_id][m][nc];
2082
2048
if (k_splits > 1 ) {
@@ -3389,7 +3355,6 @@ at::Tensor qlinear_woq_affine(
3389
3355
const TensorList& bias_list,
3390
3356
const int qw_type,
3391
3357
int64_t lowp_mode,
3392
- int64_t num_concats,
3393
3358
int64_t fusion_type,
3394
3359
const TensorList& others_list,
3395
3360
int64_t quant_a_mode = -1 ,
@@ -3449,7 +3414,6 @@ at::Tensor qlinear_woq_affine(
3449
3414
y,
3450
3415
qw_type,
3451
3416
k_splits,
3452
- num_concats,
3453
3417
fusion_type,
3454
3418
others_list,
3455
3419
quant_block_k);
@@ -3470,7 +3434,6 @@ at::Tensor qlinear_woq_affine(
3470
3434
y,
3471
3435
qw_type,
3472
3436
k_splits,
3473
- num_concats,
3474
3437
fusion_type,
3475
3438
others_list,
3476
3439
quant_block_k,
@@ -3494,7 +3457,6 @@ at::Tensor qlinear_woq_affine(
3494
3457
y,
3495
3458
qw_type,
3496
3459
k_splits,
3497
- num_concats,
3498
3460
fusion_type,
3499
3461
others_list,
3500
3462
quant_block_k);
@@ -3515,7 +3477,6 @@ at::Tensor qlinear_woq_affine(
3515
3477
y,
3516
3478
qw_type,
3517
3479
k_splits,
3518
- num_concats,
3519
3480
fusion_type,
3520
3481
others_list,
3521
3482
quant_block_k,
@@ -3544,7 +3505,6 @@ at::Tensor qlinear_woq_affine(
3544
3505
y,
3545
3506
qw_type,
3546
3507
k_splits,
3547
- num_concats,
3548
3508
fusion_type,
3549
3509
others_list,
3550
3510
quant_block_k);
@@ -3565,7 +3525,6 @@ at::Tensor qlinear_woq_affine(
3565
3525
y,
3566
3526
qw_type,
3567
3527
k_splits,
3568
- num_concats,
3569
3528
fusion_type,
3570
3529
others_list,
3571
3530
quant_block_k,
@@ -3589,7 +3548,6 @@ at::Tensor qlinear_woq_affine(
3589
3548
y,
3590
3549
qw_type,
3591
3550
k_splits,
3592
- num_concats,
3593
3551
fusion_type,
3594
3552
others_list,
3595
3553
quant_block_k);
@@ -3610,7 +3568,6 @@ at::Tensor qlinear_woq_affine(
3610
3568
y,
3611
3569
qw_type,
3612
3570
k_splits,
3613
- num_concats,
3614
3571
fusion_type,
3615
3572
others_list,
3616
3573
quant_block_k,
@@ -3639,7 +3596,6 @@ at::Tensor qlinear_woq_affine(
3639
3596
y,
3640
3597
qw_type,
3641
3598
k_splits,
3642
- num_concats,
3643
3599
fusion_type,
3644
3600
others_list,
3645
3601
quant_block_k);
@@ -3660,7 +3616,6 @@ at::Tensor qlinear_woq_affine(
3660
3616
y,
3661
3617
qw_type,
3662
3618
k_splits,
3663
- num_concats,
3664
3619
fusion_type,
3665
3620
others_list,
3666
3621
quant_block_k,
@@ -3697,7 +3652,6 @@ at::Tensor qlinear_woq_affine(
3697
3652
y,
3698
3653
qw_type,
3699
3654
k_splits,
3700
- num_concats,
3701
3655
fusion_type,
3702
3656
others_list,
3703
3657
quant_block_k,
@@ -3738,7 +3692,6 @@ at::Tensor qlinear_woq_affine(
3738
3692
y,
3739
3693
qw_type,
3740
3694
k_splits,
3741
- num_concats,
3742
3695
fusion_type,
3743
3696
others_list,
3744
3697
quant_block_k,
@@ -3858,12 +3811,6 @@ at::Tensor qlinear_woq_affine(
3858
3811
: bf16_idx;
3859
3812
y = at::add (y, biases[b_index]);
3860
3813
}
3861
- if (num_concats > 1 ) {
3862
- y = y.view ({-1 , num_concats, y.size (-1 ) / num_concats})
3863
- .transpose (0 , 1 )
3864
- .contiguous ()
3865
- .view ({-1 , y.size (-1 )});
3866
- }
3867
3814
if (fusion_type == FUSE_GELU_ERF) {
3868
3815
y = at::gelu (y);
3869
3816
} else if (fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD) {
@@ -3892,7 +3839,6 @@ at::Tensor qlinear_woq_affine(
3892
3839
const TensorList& bias_list,
3893
3840
const int qw_type,
3894
3841
int64_t lowp_mode,
3895
- int64_t num_concats,
3896
3842
int64_t fusion_type,
3897
3843
const TensorList& others_list,
3898
3844
int64_t quant_a_mode = -1 ,
@@ -4007,12 +3953,6 @@ at::Tensor qlinear_woq_affine(
4007
3953
: bf16_idx;
4008
3954
y = at::add (y, biases[b_index]);
4009
3955
}
4010
- if (num_concats > 1 ) {
4011
- y = y.view ({-1 , num_concats, y.size (-1 ) / num_concats})
4012
- .transpose (0 , 1 )
4013
- .contiguous ()
4014
- .view ({-1 , y.size (-1 )});
4015
- }
4016
3956
if (fusion_type == FUSE_GELU_ERF) {
4017
3957
y = at::gelu (y);
4018
3958
} else if (fusion_type == FUSE_ADD || fusion_type == FUSE_ADD_ADD) {
0 commit comments