Skip to content

Commit fd642dd

Browse files
committed
[RISCV] Combine (or disjoint ext, ext) -> vwadd
DAGCombiner (or InstCombine) will convert an add into an or when the bits of its operands are disjoint, which can prevent what was originally an (add {s,z}ext, {s,z}ext) from being selected as a vwadd. This teaches combineBinOp_VLToVWBinOp_VL to recover the widening add by treating an or that carries the disjoint flag as an add.
1 parent 856e815 commit fd642dd

File tree

2 files changed

+24
-13
lines changed

2 files changed

+24
-13
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13527,7 +13527,7 @@ struct CombineResult;
1352713527
enum ExtKind : uint8_t { ZExt = 1 << 0, SExt = 1 << 1, FPExt = 1 << 2 };
1352813528
/// Helper class for folding sign/zero extensions.
1352913529
/// In particular, this class is used for the following combines:
13530-
/// add | add_vl -> vwadd(u) | vwadd(u)_w
13530+
/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
1353113531
/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
1353213532
/// mul | mul_vl -> vwmul(u) | vwmul_su
1353313533
/// fadd -> vfwadd | vfwadd_w
@@ -13675,6 +13675,7 @@ struct NodeExtensionHelper {
1367513675
case RISCVISD::ADD_VL:
1367613676
case RISCVISD::VWADD_W_VL:
1367713677
case RISCVISD::VWADDU_W_VL:
13678+
case ISD::OR:
1367813679
return RISCVISD::VWADD_VL;
1367913680
case ISD::SUB:
1368013681
case RISCVISD::SUB_VL:
@@ -13697,6 +13698,7 @@ struct NodeExtensionHelper {
1369713698
case RISCVISD::ADD_VL:
1369813699
case RISCVISD::VWADD_W_VL:
1369913700
case RISCVISD::VWADDU_W_VL:
13701+
case ISD::OR:
1370013702
return RISCVISD::VWADDU_VL;
1370113703
case ISD::SUB:
1370213704
case RISCVISD::SUB_VL:
@@ -13742,6 +13744,7 @@ struct NodeExtensionHelper {
1374213744
switch (Opcode) {
1374313745
case ISD::ADD:
1374413746
case RISCVISD::ADD_VL:
13747+
case ISD::OR:
1374513748
return SupportsExt == ExtKind::SExt ? RISCVISD::VWADD_W_VL
1374613749
: RISCVISD::VWADDU_W_VL;
1374713750
case ISD::SUB:
@@ -13862,6 +13865,10 @@ struct NodeExtensionHelper {
1386213865
case ISD::MUL: {
1386313866
return Root->getValueType(0).isScalableVector();
1386413867
}
13868+
case ISD::OR: {
13869+
return Root->getValueType(0).isScalableVector() &&
13870+
Root->getFlags().hasDisjoint();
13871+
}
1386513872
// Vector Widening Integer Add/Sub/Mul Instructions
1386613873
case RISCVISD::ADD_VL:
1386713874
case RISCVISD::MUL_VL:
@@ -13942,7 +13949,8 @@ struct NodeExtensionHelper {
1394213949
switch (Root->getOpcode()) {
1394313950
case ISD::ADD:
1394413951
case ISD::SUB:
13945-
case ISD::MUL: {
13952+
case ISD::MUL:
13953+
case ISD::OR: {
1394613954
SDLoc DL(Root);
1394713955
MVT VT = Root->getSimpleValueType(0);
1394813956
return getDefaultScalableVLOps(VT, DL, DAG, Subtarget);
@@ -13965,6 +13973,7 @@ struct NodeExtensionHelper {
1396513973
switch (N->getOpcode()) {
1396613974
case ISD::ADD:
1396713975
case ISD::MUL:
13976+
case ISD::OR:
1396813977
case RISCVISD::ADD_VL:
1396913978
case RISCVISD::MUL_VL:
1397013979
case RISCVISD::VWADD_W_VL:
@@ -14031,6 +14040,7 @@ struct CombineResult {
1403114040
case ISD::ADD:
1403214041
case ISD::SUB:
1403314042
case ISD::MUL:
14043+
case ISD::OR:
1403414044
Merge = DAG.getUNDEF(Root->getValueType(0));
1403514045
break;
1403614046
}
@@ -14181,6 +14191,7 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1418114191
switch (Root->getOpcode()) {
1418214192
case ISD::ADD:
1418314193
case ISD::SUB:
14194+
case ISD::OR:
1418414195
case RISCVISD::ADD_VL:
1418514196
case RISCVISD::SUB_VL:
1418614197
case RISCVISD::FADD_VL:
@@ -14224,9 +14235,9 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1422414235

1422514236
/// Combine a binary operation to its equivalent VW or VW_W form.
1422614237
/// The supported combines are:
14227-
/// add_vl -> vwadd(u) | vwadd(u)_w
14228-
/// sub_vl -> vwsub(u) | vwsub(u)_w
14229-
/// mul_vl -> vwmul(u) | vwmul_su
14238+
/// add | add_vl | or disjoint -> vwadd(u) | vwadd(u)_w
14239+
/// sub | sub_vl -> vwsub(u) | vwsub(u)_w
14240+
/// mul | mul_vl -> vwmul(u) | vwmul_su
1423014241
/// fadd_vl -> vfwadd | vfwadd_w
1423114242
/// fsub_vl -> vfwsub | vfwsub_w
1423214243
/// fmul_vl -> vfwmul
@@ -15886,8 +15897,11 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1588615897
}
1588715898
case ISD::AND:
1588815899
return performANDCombine(N, DCI, Subtarget);
15889-
case ISD::OR:
15900+
case ISD::OR: {
15901+
if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
15902+
return V;
1589015903
return performORCombine(N, DCI, Subtarget);
15904+
}
1589115905
case ISD::XOR:
1589215906
return performXORCombine(N, DAG, Subtarget);
1589315907
case ISD::MUL:

llvm/test/CodeGen/RISCV/rvv/vwadd-sdnode.ll

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,18 +1394,15 @@ define <vscale x 1 x i64> @i1_zext(<vscale x 1 x i1> %va, <vscale x 1 x i64> %vb
13941394
}
13951395

13961396
; %x.i32 and %y.i32 have disjoint bits, so DAGCombiner will combine their add into an or.
1397-
; FIXME: We should be able to recover the or into vwaddu.vv if the disjoint
1398-
; flag is set.
1397+
; Check that we combine disjoint ors into vwaddu.
13991398
define <vscale x 2 x i32> @disjoint_or(<vscale x 2 x i8> %x.i8, <vscale x 2 x i8> %y.i8) {
14001399
; CHECK-LABEL: disjoint_or:
14011400
; CHECK: # %bb.0:
14021401
; CHECK-NEXT: vsetvli a0, zero, e16, mf2, ta, ma
14031402
; CHECK-NEXT: vzext.vf2 v10, v8
1404-
; CHECK-NEXT: vsll.vi v8, v10, 8
1405-
; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
1406-
; CHECK-NEXT: vzext.vf2 v10, v8
1407-
; CHECK-NEXT: vzext.vf4 v8, v9
1408-
; CHECK-NEXT: vor.vv v8, v10, v8
1403+
; CHECK-NEXT: vsll.vi v10, v10, 8
1404+
; CHECK-NEXT: vzext.vf2 v11, v9
1405+
; CHECK-NEXT: vwaddu.vv v8, v10, v11
14091406
; CHECK-NEXT: ret
14101407
%x.i16 = zext <vscale x 2 x i8> %x.i8 to <vscale x 2 x i16>
14111408
%x.shl = shl <vscale x 2 x i16> %x.i16, shufflevector(<vscale x 2 x i16> insertelement(<vscale x 2 x i16> poison, i16 8, i32 0), <vscale x 2 x i16> poison, <vscale x 2 x i32> zeroinitializer)

0 commit comments

Comments
 (0)