[RISCV] Vector sub (zext, zext) -> sext (sub (zext, zext)) #82455

Merged 3 commits on Feb 23, 2024.

llvm/lib/Target/RISCV/RISCVISelLowering.cpp (25 changes: 24 additions & 1 deletion)

@@ -12846,21 +12846,44 @@ static SDValue performSUBCombine(SDNode *N, SelectionDAG &DAG,
   if (SDValue V = combineSubOfBoolean(N, DAG))
     return V;

+  EVT VT = N->getValueType(0);
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
   // fold (sub 0, (setcc x, 0, setlt)) -> (sra x, xlen - 1)
   if (isNullConstant(N0) && N1.getOpcode() == ISD::SETCC && N1.hasOneUse() &&
       isNullConstant(N1.getOperand(1))) {
     ISD::CondCode CCVal = cast<CondCodeSDNode>(N1.getOperand(2))->get();
     if (CCVal == ISD::SETLT) {
-      EVT VT = N->getValueType(0);
       SDLoc DL(N);
       unsigned ShAmt = N0.getValueSizeInBits() - 1;
       return DAG.getNode(ISD::SRA, DL, VT, N1.getOperand(0),
                          DAG.getConstant(ShAmt, DL, VT));
     }
   }

+  // sub (zext, zext) -> sext (sub (zext, zext))
+  //   where the sum of the extend widths match, and the inner zexts
+  //   add at least one bit. (For profitability on rvv, we use a
+  //   power of two for both inner and outer extend.)

Collaborator: how are we guaranteeing power of 2 here?

Collaborator (author): This is a case of a comment being out of sync with my mental model. Will fix, but let me explain what I'm thinking and you can tell me if this is sane or not.

I was originally intending to have an isTypeLegal check on both VT and SrcVT. That, combined with the srcvt >= 8 check, should ensure that all of the types are e8, e16, e32, or e64.

Then I started thinking about illegal types. I think they fall into two camps: reasonable ones such as i128, and odd ones such as i34. For the former, narrowing before legalization (splitting, I think?) seems likely profitable. For the latter, we might end up with e.g. an e17 intermediate type, but that'll get promoted to i32 and i64 respectively. So, a reasonable overall result? (Though I now notice there's an edge case here with e.g. i33 not having a half-sized type.)

What do you think? Should I fix the edge case and allow illegal types? Require legal types? Something else?

Collaborator: getHalfSizedIntegerVT rounds up odd types so that the new type covers at least half; it will return i17 for i33.

Collaborator: If the result type is i128, any operation with the i128 element type as either source or dest will get split repeatedly until it can be scalarized; then the resulting scalar ops with illegal scalar types will get further legalized to XLen. CodeGen will be so bad I'm not sure it's worth optimizing.

For other illegal types, they should get promoted to the next power of 2. After that, your combine would run again and have another chance at it. So it might be fine to check for legal types?
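
A hypothetical IR instance of the i128 camp (not taken from the patch's tests): the <2 x i128> result type is illegal for RVV, so with the legality checks the combine leaves this sub for type legalization to split rather than narrowing it.

    ; Hypothetical example: <2 x i128> is not a legal result type, so the
    ; combine skips it and type legalization splits/scalarizes it instead.
    define <2 x i128> @sub_zext_i128(<2 x i32> %x, <2 x i32> %y) {
      %xz = zext <2 x i32> %x to <2 x i128>
      %yz = zext <2 x i32> %y to <2 x i128>
      %d = sub <2 x i128> %xz, %yz
      ret <2 x i128> %d
    }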

Collaborator (author): Added the legality checks.

For my understanding, doesn't DAG combine run before type legalization (as well as after)? Given that, wouldn't narrowing an i128 add to i64 mean that only the sext would be legalized?

Collaborator:

> For my understanding, doesn't DAG combine run before type legalization (as well as after)? Given that, wouldn't narrowing an i128 add to i64 mean that only the sext would be legalized?

Yes. We'd only scalarize the sext. My thinking was that if we have to generate slidedowns and extracts to scalarize some part of the calculation, it didn't make much sense to use vectors in the first place. But I guess it depends on how many elements and how much computation there is before the sub(zext, zext).

+  if (VT.isVector() && Subtarget.getTargetLowering()->isTypeLegal(VT) &&
+      N0.getOpcode() == N1.getOpcode() && N0.getOpcode() == ISD::ZERO_EXTEND &&
+      N0.hasOneUse() && N1.hasOneUse()) {
+    SDValue Src0 = N0.getOperand(0);
+    SDValue Src1 = N1.getOperand(0);
+    EVT SrcVT = Src0.getValueType();
+    if (Subtarget.getTargetLowering()->isTypeLegal(SrcVT) &&
+        SrcVT == Src1.getValueType() && SrcVT.getScalarSizeInBits() >= 8 &&
+        SrcVT.getScalarSizeInBits() < VT.getScalarSizeInBits() / 2) {
+      LLVMContext &C = *DAG.getContext();
+      EVT ElemVT = VT.getVectorElementType().getHalfSizedIntegerVT(C);

Collaborator: If the types are (i64 (sub (zext i8), (zext i8))) this will produce (i64 (sext (i32 (sub (zext i8), (zext i8))))). Would it be better to do the sub at i16?

Collaborator (author): The combiner will iterate, won't it? So first we'd produce (i64 (sext (i32 (sub (zext i8), (zext i8))))) and then the transform would run again, producing (i64 (sext (i16 (sub (zext i8), (zext i8))))). (There'd be one extra step to fold the two sexts together.)
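
Sketched step by step in the same notation (assuming the combiner revisits the nodes it creates, as it normally does), the expected iteration would be:

    (i64 (sub (zext i8), (zext i8)))
    -> (i64 (sext (i32 (sub (zext i8), (zext i8)))))              ; first visit: 8 < 64/2
    -> (i64 (sext (i32 (sext (i16 (sub (zext i8), (zext i8))))))) ; revisit of the i32 sub: 8 < 32/2
    -> (i64 (sext (i16 (sub (zext i8), (zext i8)))))              ; sext-of-sext folds
    ; stops here: 8 is not less than 16/2, so the i16 sub stands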

Collaborator: You're right, I didn't think about that.

+      EVT NarrowVT = EVT::getVectorVT(C, ElemVT, VT.getVectorElementCount());
+      Src0 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src0), NarrowVT, Src0);
+      Src1 = DAG.getNode(ISD::ZERO_EXTEND, SDLoc(Src1), NarrowVT, Src1);
+      return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), VT,
+                         DAG.getNode(ISD::SUB, SDLoc(N), NarrowVT, Src0, Src1));
+    }
+  }

   // fold (sub x, (select lhs, rhs, cc, 0, y)) ->
   //      (select lhs, rhs, cc, x, (sub x, y))
   return combineSelectAndUse(N, N1, N0, DAG, /*AllOnes*/ false, Subtarget);
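
For reference, a minimal IR sketch of the pattern the combine targets (the function name is illustrative; the shape mirrors the vwsubu_v2i32_v2i8 test updated below). Both operands of the i32 sub are zero-extended from i8, well past half the result width, so the sub can run at i16 and be sign-extended, which on RVV maps to a single widening vwsubu.vv (e8 to e16) plus a vsext.vf2:

    define <2 x i32> @sub_of_zexts(<2 x i8> %x, <2 x i8> %y) {
      ; the zexts add more than VT/2 bits, so: sub at i16, then sext to i32
      %xz = zext <2 x i8> %x to <2 x i32>
      %yz = zext <2 x i8> %y to <2 x i32>
      %d = sub <2 x i32> %xz, %yz
      ret <2 x i32> %d
    }
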
llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll (32 changes: 16 additions & 16 deletions)

@@ -385,12 +385,12 @@ define <32 x i64> @vwsubu_v32i64(ptr %x, ptr %y) nounwind {
 define <2 x i32> @vwsubu_v2i32_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i32_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -899,12 +899,12 @@ define <2 x i64> @vwsubu_vx_v2i64_i64(ptr %x, ptr %y) nounwind {
 define <2 x i32> @vwsubu_v2i32_of_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i32_of_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -917,12 +917,12 @@ define <2 x i32> @vwsubu_v2i32_of_v2i8(ptr %x, ptr %y) {
 define <2 x i64> @vwsubu_v2i64_of_v2i8(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i64_of_v2i8:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a0)
 ; CHECK-NEXT:    vle8.v v9, (a1)
-; CHECK-NEXT:    vzext.vf4 v10, v8
-; CHECK-NEXT:    vzext.vf4 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vsext.vf4 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i8>, ptr %x
   %b = load <2 x i8>, ptr %y
@@ -935,12 +935,12 @@ define <2 x i64> @vwsubu_v2i64_of_v2i8(ptr %x, ptr %y) {
 define <2 x i64> @vwsubu_v2i64_of_v2i16(ptr %x, ptr %y) {
 ; CHECK-LABEL: vwsubu_v2i64_of_v2i16:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    vsetivli zero, 2, e32, mf2, ta, ma
+; CHECK-NEXT:    vsetivli zero, 2, e16, mf4, ta, ma
 ; CHECK-NEXT:    vle16.v v8, (a0)
 ; CHECK-NEXT:    vle16.v v9, (a1)
-; CHECK-NEXT:    vzext.vf2 v10, v8
-; CHECK-NEXT:    vzext.vf2 v11, v9
-; CHECK-NEXT:    vwsubu.vv v8, v10, v11
+; CHECK-NEXT:    vwsubu.vv v10, v8, v9
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vsext.vf2 v8, v10
 ; CHECK-NEXT:    ret
   %a = load <2 x i16>, ptr %x
   %b = load <2 x i16>, ptr %y
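
The CHECK lines in this file follow the update_llc_test_checks.py format. Assuming the file's existing RUN lines (llc with -mattr=+v for riscv32/riscv64), they can be regenerated with:

    llvm/utils/update_llc_test_checks.py llvm/test/CodeGen/RISCV/rvv/fixed-vectors-vwsubu.ll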