[RISCV] Teach combineBinOpOfZExt to narrow based on known bits #86680

Closed. Wants to merge 1 commit.
llvm/lib/Target/RISCV/RISCVISelLowering.cpp (27 additions, 9 deletions)
@@ -12944,32 +12944,50 @@ static SDValue transformAddImmMulImm(SDNode *N, SelectionDAG &DAG,
static SDValue combineBinOpOfZExt(SDNode *N, SelectionDAG &DAG) {

EVT VT = N->getValueType(0);
if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT))
if (!VT.isVector() || !DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
VT.getScalarSizeInBits() <= 8)
return SDValue();

SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
if (N0.getOpcode() != ISD::ZERO_EXTEND || N1.getOpcode() != ISD::ZERO_EXTEND)
return SDValue();
// TODO: Can relax these checks when we're not needing to insert a new extend
// on one side or the other..
if (!N0.hasOneUse() || !N1.hasOneUse())
return SDValue();

SDValue Src0 = N0.getOperand(0);
SDValue Src1 = N1.getOperand(0);
EVT SrcVT = Src0.getValueType();
if (!DAG.getTargetLoweringInfo().isTypeLegal(SrcVT) ||
SrcVT != Src1.getValueType() || SrcVT.getScalarSizeInBits() < 8 ||
SrcVT.getScalarSizeInBits() >= VT.getScalarSizeInBits() / 2)
EVT Src0VT = Src0.getValueType();
EVT Src1VT = Src0.getValueType();

if (!DAG.getTargetLoweringInfo().isTypeLegal(Src0VT) ||
!DAG.getTargetLoweringInfo().isTypeLegal(Src1VT))
return SDValue();

unsigned HalfBitWidth = VT.getScalarSizeInBits() / 2;
if (Src0VT.getScalarSizeInBits() >= HalfBitWidth) {
KnownBits Known = DAG.computeKnownBits(Src0);
if (Known.countMinLeadingZeros() <= HalfBitWidth)
return SDValue();
}
if (Src1VT.getScalarSizeInBits() >= HalfBitWidth) {
KnownBits Known = DAG.computeKnownBits(Src0);
if (Known.countMinLeadingZeros() <= HalfBitWidth)
return SDValue();
}

LLVMContext &C = *DAG.getContext();
EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);
EVT ElemVT = EVT::getIntegerVT(C, HalfBitWidth);
EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());

Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
if (Src0VT != NarrowVT)
Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
if (Src1VT != NarrowVT)
Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);

// Src0 and Src1 are zero extended, so they're always positive if signed.
// Src0 and Src1 are always positive if signed.
//
// sub can produce a negative from two positive operands, so it needs sign
// extended. Other nodes produce a positive from two positive operands, so
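For orientation, here is a minimal standalone sketch in plain C++ of the per-operand gate the patch adds. It is not LLVM code: the type and function names are invented, and it assumes the two known-bits blocks are meant to query each source's own value. A zero-extended source passes either because its element type is already narrower than half of the result width, or because a known-bits query reports more than a half-width's worth of leading zeros.

#include <cstdio>

// Stand-in for the result of KnownBits::countMinLeadingZeros().
struct KnownBitsSummary {
  unsigned MinLeadingZeros;
};

// Sketch of the per-operand check in combineBinOpOfZExt after the patch:
// bail out unless the operand is provably narrow enough for the half-width
// binop the combine wants to form.
static bool operandFitsNarrowOp(unsigned SrcScalarBits,
                                KnownBitsSummary Known,
                                unsigned ResultScalarBits) {
  unsigned HalfBitWidth = ResultScalarBits / 2;
  // Sources narrower than the half-width type need no proof; the combine
  // simply re-extends them to the narrow type.
  if (SrcScalarBits < HalfBitWidth)
    return true;
  // Wider sources are accepted only when the known leading zeros exceed the
  // half width; the patch bails out on countMinLeadingZeros() <= HalfBitWidth.
  return Known.MinLeadingZeros > HalfBitWidth;
}

int main() {
  // i8 feeding an i32 binop: always eligible.
  std::printf("%d\n", operandFitsNarrowOp(8, {0}, 32));
  // i16 feeding an i32 binop with nothing known about its bits: rejected.
  std::printf("%d\n", operandFitsNarrowOp(16, {0}, 32));
  return 0;
}

The practical effect visible in the tests below is that a source at least half as wide as the result, such as an i16 feeding an i32 add, is now a candidate for the known-bits gate instead of being rejected outright by the old SrcVT.getScalarSizeInBits() >= VT.getScalarSizeInBits() / 2 check.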
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-sad.ll (21 additions, 24 deletions)
@@ -98,41 +98,38 @@ define signext i32 @sad_2block_16xi8_as_i32(ptr %a, ptr %b, i32 signext %stridea
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vle8.v v9, (a1)
; CHECK-NEXT: vminu.vv v10, v8, v9
; CHECK-NEXT: vmaxu.vv v8, v8, v9
; CHECK-NEXT: vsub.vv v8, v8, v10
; CHECK-NEXT: add a0, a0, a2
; CHECK-NEXT: add a1, a1, a3
; CHECK-NEXT: vle8.v v10, (a0)
; CHECK-NEXT: vle8.v v11, (a1)
; CHECK-NEXT: vminu.vv v12, v8, v9
; CHECK-NEXT: vmaxu.vv v8, v8, v9
; CHECK-NEXT: vsub.vv v8, v8, v12
; CHECK-NEXT: vminu.vv v9, v10, v11
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vle8.v v10, (a1)
; CHECK-NEXT: add a0, a0, a2
; CHECK-NEXT: add a1, a1, a3
; CHECK-NEXT: vle8.v v11, (a0)
; CHECK-NEXT: vle8.v v12, (a1)
; CHECK-NEXT: vminu.vv v13, v9, v10
; CHECK-NEXT: vmaxu.vv v9, v9, v10
; CHECK-NEXT: vsub.vv v9, v9, v13
; CHECK-NEXT: vminu.vv v10, v11, v12
; CHECK-NEXT: vmaxu.vv v11, v11, v12
; CHECK-NEXT: add a0, a0, a2
; CHECK-NEXT: add a1, a1, a3
; CHECK-NEXT: vle8.v v12, (a0)
; CHECK-NEXT: vle8.v v13, (a1)
; CHECK-NEXT: vmaxu.vv v10, v10, v11
; CHECK-NEXT: vsub.vv v9, v10, v9
; CHECK-NEXT: vwaddu.vv v10, v9, v8
; CHECK-NEXT: vsub.vv v10, v11, v10
; CHECK-NEXT: vwaddu.vv v14, v9, v8
; CHECK-NEXT: vwaddu.wv v14, v14, v10
; CHECK-NEXT: vminu.vv v8, v12, v13
; CHECK-NEXT: vmaxu.vv v9, v12, v13
; CHECK-NEXT: vsub.vv v8, v9, v8
; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
; CHECK-NEXT: add a0, a0, a2
; CHECK-NEXT: add a1, a1, a3
; CHECK-NEXT: vle8.v v9, (a0)
; CHECK-NEXT: vle8.v v12, (a1)
; CHECK-NEXT: vzext.vf2 v14, v8
; CHECK-NEXT: vwaddu.vv v16, v14, v10
; CHECK-NEXT: vsetvli zero, zero, e8, m1, ta, ma
; CHECK-NEXT: vminu.vv v8, v9, v12
; CHECK-NEXT: vmaxu.vv v9, v9, v12
; CHECK-NEXT: vsub.vv v8, v9, v8
; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
; CHECK-NEXT: vzext.vf2 v10, v8
; CHECK-NEXT: vwaddu.wv v16, v16, v10
; CHECK-NEXT: vwaddu.wv v14, v14, v8
; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
; CHECK-NEXT: vmv.s.x v8, zero
; CHECK-NEXT: vredsum.vs v8, v16, v8
; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
; CHECK-NEXT: vwredsumu.vs v8, v14, v8
; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
; CHECK-NEXT: vmv.x.s a0, v8
; CHECK-NEXT: ret
entry:
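The CHECK lines above are easier to read against a scalar reference of the pattern being compiled. The sketch below is only for orientation: each vminu/vmaxu/vsub triple computes an unsigned absolute difference, and the widening adds and the final reduction accumulate those differences into a 32-bit sum. The full IR of the test, including the exact number of 16-byte rows, lies outside the visible hunk, so the row count is left as a parameter and all names here are invented.

#include <algorithm>
#include <cstddef>
#include <cstdint>

// Scalar sketch of the sum-of-absolute-differences idiom the vector code
// implements: for unsigned bytes, |x - y| == max(x, y) - min(x, y), which is
// what each vminu/vmaxu/vsub triple computes lane-wise.
static uint32_t sadRows16xi8(const uint8_t *A, const uint8_t *B,
                             std::size_t StrideA, std::size_t StrideB,
                             int NumRows) {
  uint32_t Sum = 0;
  for (int Row = 0; Row < NumRows; ++Row) {
    for (int I = 0; I < 16; ++I) {
      uint8_t X = A[I], Y = B[I];
      // Absolute difference of two unsigned bytes, widened before summing,
      // mirroring the vwaddu/vwredsumu accumulation in the CHECK lines.
      Sum += static_cast<uint32_t>(std::max(X, Y) - std::min(X, Y));
    }
    A += StrideA; // the scalar add a0, a0, a2 steps in the assembly
    B += StrideB; // the scalar add a1, a1, a3 steps
  }
  return Sum;
}

int main() {
  // Two 16-byte rows laid out back to back (stride 16).
  uint8_t A[32] = {10, 200}, B[32] = {14, 190};
  return static_cast<int>(sadRows16xi8(A, B, 16, 16, 2)); // 4 + 10 = 14
}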
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwaddu.ll (4 additions, 3 deletions)
@@ -403,11 +403,12 @@ define <2 x i32> @vwaddu_v2i32_v2i8(ptr %x, ptr %y) {
define <4 x i32> @vwaddu_v4i32_v4i8_v4i16(ptr %x, ptr %y) {
; CHECK-LABEL: vwaddu_v4i32_v4i8_v4i16:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetivli zero, 4, e16, mf2, ta, ma
; CHECK-NEXT: vsetivli zero, 4, e8, mf4, ta, ma
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vle16.v v9, (a1)
; CHECK-NEXT: vzext.vf2 v10, v8
; CHECK-NEXT: vwaddu.vv v8, v10, v9
; CHECK-NEXT: vwaddu.wv v9, v9, v8
; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
; CHECK-NEXT: vzext.vf2 v8, v9
; CHECK-NEXT: ret
%a = load <4 x i8>, ptr %x
%b = load <4 x i16>, ptr %y
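For reference, one lane of the pattern this test exercises, inferred from the test name and the visible loads since the zext/add IR itself sits below the visible hunk, looks like this in scalar terms:

#include <cstdint>

// One lane of vwaddu_v4i32_v4i8_v4i16: an i8 element and an i16 element are
// both zero-extended and added to produce an i32 element. The i16 side is
// exactly half of the i32 result width, which is the case the new known-bits
// gate in combineBinOpOfZExt is concerned with.
static uint32_t vwaddu_lane(uint8_t A, uint16_t B) {
  return static_cast<uint32_t>(A) + static_cast<uint32_t>(B);
}

int main() { return vwaddu_lane(1, 2) == 3 ? 0 : 1; }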
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll (3 additions, 1 deletion)
@@ -407,7 +407,9 @@ define <4 x i32> @vwsubu_v4i32_v4i8_v4i16(ptr %x, ptr %y) {
; CHECK-NEXT: vle8.v v8, (a0)
; CHECK-NEXT: vle16.v v9, (a1)
; CHECK-NEXT: vzext.vf2 v10, v8
; CHECK-NEXT: vwsubu.vv v8, v10, v9
; CHECK-NEXT: vsub.vv v9, v10, v9
; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, ma
; CHECK-NEXT: vsext.vf2 v8, v9
; CHECK-NEXT: ret
%a = load <4 x i8>, ptr %x
%b = load <4 x i16>, ptr %y
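The vwsubu counterpart is the case the comment in combineBinOpOfZExt calls out: two zero-extended, hence non-negative, operands can still produce a negative difference, so the narrowed sub is sign-extended back to the result width (the vsext.vf2 in the new CHECK lines). One lane in scalar terms, again inferred from the test name and the visible loads:

#include <cstdint>

// One lane of vwsubu_v4i32_v4i8_v4i16: both inputs are zero-extended to i32
// and subtracted. Although both extended operands are non-negative, the
// difference itself can be negative, which is why the combine emits a sign
// extension when the narrowed operation is a sub.
static int32_t vwsubu_lane(uint8_t A, uint16_t B) {
  return static_cast<int32_t>(static_cast<uint32_t>(A)) -
         static_cast<int32_t>(static_cast<uint32_t>(B));
}

int main() { return vwsubu_lane(1, 2) == -1 ? 0 : 1; }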