Skip to content

Commit 373e77b

Browse files
authored
[RISCV] Generalize (sub zext, zext) -> (sext (sub zext, zext)) to add (#86248)
This generalizes the combine added in #82455 to other binary ops, beginning with adds in this patch. Because the two zext operands are always +ve when treated as signed, and we don't get any overflow since the add is carried out in at least N * 2 bits of the narrow type, the result of the add will always be +ve. So we can use a zext for the outer extend, unlike sub which may produce a -ve result from two +ve operands. Although we could still use sext for add, I plan to add support for other binary ops like mul in a later patch, but mul requires zext to be correct (because the maximum value will take up the full N * 2 bits). So I've opted to use zext here too for consistency. Alive2 proof: https://alive2.llvm.org/ce/z/PRNsUM
1 parent d9746a6 commit 373e77b

File tree

3 files changed

+125
-110
lines changed

3 files changed

+125
-110
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12913,6 +12913,55 @@ static SDValue transformAddImmMulImm(SDNode *N, SelectionDAG &DAG,
1291312913
return DAG.getNode(ISD::ADD, DL, VT, New1, DAG.getConstant(CB, DL, VT));
1291412914
}
1291512915

12916+
// add (zext, zext) -> zext (add (zext, zext))
12917+
// sub (zext, zext) -> sext (sub (zext, zext))
12918+
//
12919+
// where the sum of the extend widths match, and the range of the bin op
12920+
// fits inside the width of the narrower bin op. (For profitability on rvv, we
12921+
// use a power of two for both inner and outer extend.)
12922+
//
12923+
// TODO: Extend this to other binary ops
12924+
// Rewrites a vector binary op N whose two operands are both single-use
// ZERO_EXTENDs from the same source type into the same op performed at half
// of N's element width, followed by one outer extend back to N's type.
// Returns the new node, or an empty SDValue when N does not match.
static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG) {
12925+
12926+
// Only handle vector results whose type the target reports as legal.
EVT VT = N->getValueType(0);
12927+
if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT))
12928+
return SDValue();
12929+
12930+
// Both operands must be zero extends, and each must have exactly one use.
SDValue N0 = N->getOperand(0);
12931+
SDValue N1 = N->getOperand(1);
12932+
if (N0.getOpcode() != ISD::ZERO_EXTEND || N1.getOpcode() != ISD::ZERO_EXTEND)
12933+
return SDValue();
12934+
if (!N0.hasOneUse() || !N1.hasOneUse())
12935+
return SDValue();
12936+
12937+
// The pre-extension sources must share one legal type whose elements are at
// least 8 bits wide and strictly narrower than half of VT's element width,
// so that the half-width op below still has spare high bits.
SDValue Src0 = N0.getOperand(0);
12938+
SDValue Src1 = N1.getOperand(0);
12939+
EVT SrcVT = Src0.getValueType();
12940+
if (!DAG.getTargetLoweringInfo().isTypeLegal(SrcVT) ||
12941+
SrcVT != Src1.getValueType() || SrcVT.getScalarSizeInBits() < 8 ||
12942+
SrcVT.getScalarSizeInBits() >= VT.getScalarSizeInBits() / 2)
12943+
return SDValue();
12944+
12945+
// NarrowVT keeps VT's element count but uses elements half as wide as VT's.
LLVMContext &C = *DAG.getContext();
12946+
EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
12947+
EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
12948+
12949+
// Re-extend the original sources only as far as NarrowVT.
Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
12950+
Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
12951+
12952+
// Src0 and Src1 are zero extended, so they're always positive if signed.
12953+
//
12954+
// sub can produce a negative from two positive operands, so it needs sign
12955+
// extended. Other nodes produce a positive from two positive operands, so
12956+
// zero extend instead.
12957+
unsigned OuterExtend =
12958+
N->getOpcode() == ISD::SUB ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
12959+
12960+
// Build N's opcode on the narrow operands, then widen the result to VT.
return DAG.getNode(
12961+
OuterExtend, SDLoc(N), VT,
12962+
DAG.getNode(N->getOpcode(), SDLoc(N), NarrowVT, Src0, Src1));
12963+
}
12964+
1291612965
// Try to turn (add (xor bool, 1) -1) into (neg bool).
1291712966
static SDValue combineAddOfBooleanXor(SDNode *N, SelectionDAG &DAG) {
1291812967
SDValue N0 = N->getOperand(0);
@@ -12950,6 +12999,8 @@ static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG,
1295012999
return V;
1295113000
if (SDValue V = combineBinOpOfExtractToReduceTree(N, DAG, Subtarget))
1295213001
return V;
13002+
if (SDValue V = combineBinOpOfZExt(N, DAG))
13003+
return V;
1295313004

1295413005
// fold (add (select lhs, rhs, cc, 0, y), x) ->
1295513006
// (select lhs, rhs, cc, x, (add x, y))
@@ -13017,28 +13068,8 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
1301713068
}
1301813069
}
1301913070

13020-
// sub (zext, zext) -> sext (sub (zext, zext))
13021-
// where the sum of the extend widths match, and the inner zexts
13022-
// add at least one bit. (For profitability on rvv, we use a
13023-
// power of two for both inner and outer extend.)
13024-
if (VT.isVector() && Subtarget.getTargetLowering()->isTypeLegal(VT) &&
13025-
N0.getOpcode() == N1.getOpcode() && N0.getOpcode() == ISD::ZERO_EXTEND &&
13026-
N0.hasOneUse() && N1.hasOneUse()) {
13027-
SDValue Src0 = N0.getOperand(0);
13028-
SDValue Src1 = N1.getOperand(0);
13029-
EVT SrcVT = Src0.getValueType();
13030-
if (Subtarget.getTargetLowering()->isTypeLegal(SrcVT) &&
13031-
SrcVT == Src1.getValueType() && SrcVT.getScalarSizeInBits() >= 8 &&
13032-
SrcVT.getScalarSizeInBits() < VT.getScalarSizeInBits() / 2) {
13033-
LLVMContext &C = *DAG.getContext();
13034-
EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
13035-
EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
13036-
Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
13037-
Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
13038-
return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT,
13039-
DAG.getNode(ISD::SUB, SDLoc(N), NarrowVT, Src0, Src1));
13040-
}
13041-
}
13071+
if (SDValue V = combineBinOpOfZExt(N, DAG))
13072+
return V;
1304213073

1304313074
// fold (sub x, (select lhs, rhs, cc, 0, y)) ->
1304413075
// (select lhs, rhs, cc, x, (sub x, y))

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -385,12 +385,12 @@ define <32 x i64> @vwaddu_v32i64(ptr %x, ptr %y) nounwind {
385385
define <2 x i32> @vwaddu_v2i32_v2i8(ptr %x, ptr %y) {
386386
; CHECK-LABEL: vwaddu_v2i32_v2i8:
387387
; CHECK: # %bb.0:
388-
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
388+
; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
389389
; CHECK-NEXT: vle8.v v8, (a0)
390390
; CHECK-NEXT: vle8.v v9, (a1)
391-
; CHECK-NEXT: vzext.vf2 v10, v8
392-
; CHECK-NEXT: vzext.vf2 v11, v9
393-
; CHECK-NEXT: vwaddu.vv v8, v10, v11
391+
; CHECK-NEXT: vwaddu.vv v10, v8, v9
392+
; CHECK-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
393+
; CHECK-NEXT: vzext.vf2 v8, v10
394394
; CHECK-NEXT: ret
395395
%a = load <2 x i8>, ptr %x
396396
%b = load <2 x i8>, ptr %y
@@ -912,12 +912,12 @@ define <4 x i64> @crash(<4 x i16> %x, <4 x i16> %y) {
912912
define <2 x i32> @vwaddu_v2i32_of_v2i8(ptr %x, ptr %y) {
913913
; CHECK-LABEL: vwaddu_v2i32_of_v2i8:
914914
; CHECK: # %bb.0:
915-
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
915+
; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
916916
; CHECK-NEXT: vle8.v v8, (a0)
917917
; CHECK-NEXT: vle8.v v9, (a1)
918-
; CHECK-NEXT: vzext.vf2 v10, v8
919-
; CHECK-NEXT: vzext.vf2 v11, v9
920-
; CHECK-NEXT: vwaddu.vv v8, v10, v11
918+
; CHECK-NEXT: vwaddu.vv v10, v8, v9
919+
; CHECK-NEXT: vsetvli zero, zero, e32, mf2, ta, ma
920+
; CHECK-NEXT: vzext.vf2 v8, v10
921921
; CHECK-NEXT: ret
922922
%a = load <2 x i8>, ptr %x
923923
%b = load <2 x i8>, ptr %y
@@ -930,12 +930,12 @@ define <2 x i32> @vwaddu_v2i32_of_v2i8(ptr %x, ptr %y) {
930930
define <2 x i64> @vwaddu_v2i64_of_v2i8(ptr %x, ptr %y) {
931931
; CHECK-LABEL: vwaddu_v2i64_of_v2i8:
932932
; CHECK: # %bb.0:
933-
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
933+
; CHECK-NEXT: vsetivli zero, 2, e8, mf8, ta, ma
934934
; CHECK-NEXT: vle8.v v8, (a0)
935935
; CHECK-NEXT: vle8.v v9, (a1)
936-
; CHECK-NEXT: vzext.vf4 v10, v8
937-
; CHECK-NEXT: vzext.vf4 v11, v9
938-
; CHECK-NEXT: vwaddu.vv v8, v10, v11
936+
; CHECK-NEXT: vwaddu.vv v10, v8, v9
937+
; CHECK-NEXT: vsetvli zero, zero, e64, m1, ta, ma
938+
; CHECK-NEXT: vzext.vf4 v8, v10
939939
; CHECK-NEXT: ret
940940
%a = load <2 x i8>, ptr %x
941941
%b = load <2 x i8>, ptr %y
@@ -948,12 +948,12 @@ define <2 x i64> @vwaddu_v2i64_of_v2i8(ptr %x, ptr %y) {
948948
define <2 x i64> @vwaddu_v2i64_of_v2i16(ptr %x, ptr %y) {
949949
; CHECK-LABEL: vwaddu_v2i64_of_v2i16:
950950
; CHECK: # %bb.0:
951-
; CHECK-NEXT: vsetivli zero, 2, e32, mf2, ta, ma
951+
; CHECK-NEXT: vsetivli zero, 2, e16, mf4, ta, ma
952952
; CHECK-NEXT: vle16.v v8, (a0)
953953
; CHECK-NEXT: vle16.v v9, (a1)
954-
; CHECK-NEXT: vzext.vf2 v10, v8
955-
; CHECK-NEXT: vzext.vf2 v11, v9
956-
; CHECK-NEXT: vwaddu.vv v8, v10, v11
954+
; CHECK-NEXT: vwaddu.vv v10, v8, v9
955+
; CHECK-NEXT: vsetvli zero, zero, e64, m1, ta, ma
956+
; CHECK-NEXT: vzext.vf2 v8, v10
957957
; CHECK-NEXT: ret
958958
%a = load <2 x i16>, ptr %x
959959
%b = load <2 x i16>, ptr %y

0 commit comments

Comments
 (0)