[RISCV][WIP] Optimize sum of absolute differences pattern. #82722
@@ -13176,6 +13176,61 @@ static SDValue performXORCombine(SDNode *N, SelectionDAG &DAG,
   return combineSelectAndUseCommutative(N, DAG, /*AllOnes*/ false, Subtarget);
 }

// Look for (abs (sub (zext X), (zext Y))).
// Rewrite as (zext (sub (zext (max X, Y)), (zext (min X, Y)))) if the user is
// an add or reduction add. The min/max can be done in parallel and with a
// lower LMUL than the original code. The two zexts can be folded into widening
// sub and widening add or widening redsum.
static SDValue performABSCombine(SDNode *N, SelectionDAG &DAG) {
  EVT VT = N->getValueType(0);
  const TargetLowering &TLI = DAG.getTargetLoweringInfo();

  if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i32 ||
      !TLI.isTypeLegal(VT))
    return SDValue();

Comment: Doesn't need to be fixed. Or i32.
Reply: Yep, it's overfitting a workload.

  SDValue Src = N->getOperand(0);
  if (Src.getOpcode() != ISD::SUB || !Src.hasOneUse())
    return SDValue();

  // Make sure the use is an add or reduce add so the zext we create at the end
  // will be folded.
  if (!N->hasOneUse() || (N->use_begin()->getOpcode() != ISD::ADD &&
                          N->use_begin()->getOpcode() != ISD::VECREDUCE_ADD))
    return SDValue();

Comment: Instead of this, we can focus on the fact this allows a narrower representation.

  // Inputs to the subtract should be zext.
  SDValue Op0 = Src.getOperand(0);
  SDValue Op1 = Src.getOperand(1);
  if (Op0.getOpcode() != ISD::ZERO_EXTEND || !Op0.hasOneUse() ||
      Op1.getOpcode() != ISD::ZERO_EXTEND || !Op1.hasOneUse())
    return SDValue();

  Op0 = Op0.getOperand(0);
  Op1 = Op1.getOperand(0);

  // Inputs should be i8 vectors.
  if (Op0.getValueType().getVectorElementType() != MVT::i8 ||
      Op1.getValueType().getVectorElementType() != MVT::i8)
    return SDValue();

Comment: Or i8.

  SDLoc DL(N);

  SDValue Max = DAG.getNode(ISD::UMAX, DL, Op0.getValueType(), Op0, Op1);
  SDValue Min = DAG.getNode(ISD::UMIN, DL, Op0.getValueType(), Op0, Op1);

  // The intermediate VT should be i16.
  EVT IntermediateVT =
      EVT::getVectorVT(*DAG.getContext(), MVT::i16, VT.getVectorElementCount());

  Max = DAG.getNode(ISD::ZERO_EXTEND, DL, IntermediateVT, Max);
  Min = DAG.getNode(ISD::ZERO_EXTEND, DL, IntermediateVT, Min);

  SDValue Sub = DAG.getNode(ISD::SUB, DL, IntermediateVT, Max, Min);

  return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Sub);
}

static SDValue performMULCombine(SDNode *N, SelectionDAG &DAG) {
  EVT VT = N->getValueType(0);
  if (!VT.isVector())

@@ -15698,6 +15753,9 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                       DAG.getConstant(~SignBit, DL, VT));
  }
  case ISD::ABS: {
    if (SDValue V = performABSCombine(N, DAG))
      return V;

    EVT VT = N->getValueType(0);
    SDValue N0 = N->getOperand(0);
    // abs (sext) -> zext (abs)
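The scalar identity the combine relies on can be sanity-checked exhaustively. The following standalone C++ check is my sketch, not part of the patch: it verifies that for all u8 inputs, the i32 abs-of-difference equals the zero-extended i16 difference of umax and umin.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>

int main() {
  // Exhaustively check: abs(zext32(x) - zext32(y))
  //                  == zext32((u16)(umax(x, y) - umin(x, y)))
  for (int x = 0; x < 256; ++x) {
    for (int y = 0; y < 256; ++y) {
      int32_t Wide = std::abs(x - y);                                // original form
      uint16_t Narrow = uint16_t(std::max(x, y) - std::min(x, y));   // rewritten form
      assert(Wide == int32_t(Narrow));
    }
  }
  return 0;
}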
@@ -0,0 +1,120 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 2
; RUN: llc < %s -mtriple=riscv64 -mattr=+v | FileCheck %s

define signext i32 @sad(ptr %a, ptr %b) {
; CHECK-LABEL: sad:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    vsetivli zero, 4, e8, mf4, ta, ma
; CHECK-NEXT:    vle8.v v8, (a0)
; CHECK-NEXT:    vle8.v v9, (a1)
; CHECK-NEXT:    vminu.vv v10, v8, v9
; CHECK-NEXT:    vmaxu.vv v8, v8, v9
; CHECK-NEXT:    vwsubu.vv v9, v8, v10
; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
; CHECK-NEXT:    vmv.s.x v8, zero
; CHECK-NEXT:    vsetvli zero, zero, e16, mf2, ta, ma
; CHECK-NEXT:    vwredsumu.vs v8, v9, v8
; CHECK-NEXT:    vsetvli zero, zero, e32, m1, ta, ma
; CHECK-NEXT:    vmv.x.s a0, v8
; CHECK-NEXT:    ret
entry:
  %0 = load <4 x i8>, ptr %a, align 1
  %1 = zext <4 x i8> %0 to <4 x i32>
  %2 = load <4 x i8>, ptr %b, align 1
  %3 = zext <4 x i8> %2 to <4 x i32>
  %4 = sub nsw <4 x i32> %1, %3
  %5 = tail call <4 x i32> @llvm.abs.v4i32(<4 x i32> %4, i1 true)
  %6 = tail call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %5)
  ret i32 %6
}

define signext i32 @sad2(ptr %a, ptr %b, i32 signext %stridea, i32 signext %strideb) {
; CHECK-LABEL: sad2:
; CHECK:       # %bb.0: # %entry
; CHECK-NEXT:    vsetivli zero, 16, e8, m1, ta, ma
; CHECK-NEXT:    vle8.v v8, (a0)
; CHECK-NEXT:    vle8.v v9, (a1)
; CHECK-NEXT:    add a0, a0, a2
; CHECK-NEXT:    add a1, a1, a3
; CHECK-NEXT:    vle8.v v10, (a0)
; CHECK-NEXT:    vle8.v v11, (a1)
; CHECK-NEXT:    vminu.vv v12, v8, v9
; CHECK-NEXT:    vmaxu.vv v8, v8, v9
; CHECK-NEXT:    vwsubu.vv v14, v8, v12
; CHECK-NEXT:    vminu.vv v8, v10, v11
; CHECK-NEXT:    vmaxu.vv v9, v10, v11
; CHECK-NEXT:    vwsubu.vv v12, v9, v8
; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
; CHECK-NEXT:    add a0, a0, a2
; CHECK-NEXT:    add a1, a1, a3
; CHECK-NEXT:    vle8.v v16, (a0)
; CHECK-NEXT:    vle8.v v17, (a1)
; CHECK-NEXT:    vwaddu.vv v8, v12, v14
; CHECK-NEXT:    vsetvli zero, zero, e8, m1, ta, ma
; CHECK-NEXT:    vminu.vv v12, v16, v17
; CHECK-NEXT:    vmaxu.vv v13, v16, v17
; CHECK-NEXT:    vwsubu.vv v14, v13, v12
; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
; CHECK-NEXT:    add a0, a0, a2
; CHECK-NEXT:    add a1, a1, a3
; CHECK-NEXT:    vle8.v v12, (a0)
; CHECK-NEXT:    vle8.v v13, (a1)
; CHECK-NEXT:    vwaddu.wv v8, v8, v14
; CHECK-NEXT:    vsetvli zero, zero, e8, m1, ta, ma
; CHECK-NEXT:    vminu.vv v14, v12, v13
; CHECK-NEXT:    vmaxu.vv v12, v12, v13
; CHECK-NEXT:    vwsubu.vv v16, v12, v14
; CHECK-NEXT:    vsetvli zero, zero, e16, m2, ta, ma
; CHECK-NEXT:    vwaddu.wv v8, v8, v16
; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
; CHECK-NEXT:    vmv.s.x v12, zero
; CHECK-NEXT:    vredsum.vs v8, v8, v12
; CHECK-NEXT:    vmv.x.s a0, v8
; CHECK-NEXT:    ret
entry:
  %idx.ext8 = sext i32 %strideb to i64
  %idx.ext = sext i32 %stridea to i64
  %0 = load <16 x i8>, ptr %a, align 1
  %1 = zext <16 x i8> %0 to <16 x i32>
  %2 = load <16 x i8>, ptr %b, align 1
  %3 = zext <16 x i8> %2 to <16 x i32>
  %4 = sub nsw <16 x i32> %1, %3
  %5 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %4, i1 true)
  %6 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %5)
  %add.ptr = getelementptr inbounds i8, ptr %a, i64 %idx.ext
  %add.ptr9 = getelementptr inbounds i8, ptr %b, i64 %idx.ext8
  %7 = load <16 x i8>, ptr %add.ptr, align 1
  %8 = zext <16 x i8> %7 to <16 x i32>
  %9 = load <16 x i8>, ptr %add.ptr9, align 1
  %10 = zext <16 x i8> %9 to <16 x i32>
  %11 = sub nsw <16 x i32> %8, %10
  %12 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %11, i1 true)
  %13 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %12)
  %op.rdx.1 = add i32 %13, %6
  %add.ptr.1 = getelementptr inbounds i8, ptr %add.ptr, i64 %idx.ext
  %add.ptr9.1 = getelementptr inbounds i8, ptr %add.ptr9, i64 %idx.ext8
  %14 = load <16 x i8>, ptr %add.ptr.1, align 1
  %15 = zext <16 x i8> %14 to <16 x i32>
  %16 = load <16 x i8>, ptr %add.ptr9.1, align 1
  %17 = zext <16 x i8> %16 to <16 x i32>
  %18 = sub nsw <16 x i32> %15, %17
  %19 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %18, i1 true)
  %20 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %19)
  %op.rdx.2 = add i32 %20, %op.rdx.1
  %add.ptr.2 = getelementptr inbounds i8, ptr %add.ptr.1, i64 %idx.ext
  %add.ptr9.2 = getelementptr inbounds i8, ptr %add.ptr9.1, i64 %idx.ext8
  %21 = load <16 x i8>, ptr %add.ptr.2, align 1
  %22 = zext <16 x i8> %21 to <16 x i32>
  %23 = load <16 x i8>, ptr %add.ptr9.2, align 1
  %24 = zext <16 x i8> %23 to <16 x i32>
  %25 = sub nsw <16 x i32> %22, %24
  %26 = tail call <16 x i32> @llvm.abs.v16i32(<16 x i32> %25, i1 true)
  %27 = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> %26)
  %op.rdx.3 = add i32 %27, %op.rdx.2
  ret i32 %op.rdx.3
}

declare <4 x i32> @llvm.abs.v4i32(<4 x i32>, i1)
declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>)
declare <16 x i32> @llvm.abs.v16i32(<16 x i32>, i1)
declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)
Comment: Can't the sub be done at the narrower type as well? (a >=u b) should imply that (a-b) doesn't underflow, and thus the high bits are always zero?

Reply: It can, but it would leave a bare vzext.vf2 later. I was trying to carefully create a widening sub and a widening add or widening reduction to minimize the number of individual vector operations.
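The underflow observation above can also be checked exhaustively for i8. This standalone sketch (mine, not alive2 and not part of the patch) confirms that umax - umin never wraps even at the narrow type, so the widened high bits are always zero:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>

int main() {
  // Since umax(x, y) >=u umin(x, y), the i8 subtract cannot wrap, so
  // truncating the difference to u8 loses nothing.
  for (int x = 0; x < 256; ++x)
    for (int y = 0; y < 256; ++y)
      assert(uint8_t(std::max(x, y) - std::min(x, y)) == std::abs(x - y));
  return 0;
}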
Comment: For the workload in question we can pull that vzext.vf2 through the chain of adds that sums the absolute difference pieces. We could use an i16 accumulator for the beginning of the chain and switch to an i32 accumulator later in the chain.
Naive use of computeKnownBits could get us some of that by proving the overflows don't happen. We'd need to check the length of the chain in the workload to see whether it would exceed the computeKnownBits depth limit.
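A back-of-the-envelope bound for the i16-accumulator idea (the 16-element block size is an assumption for illustration, not taken from the workload):

#include <cstdint>

// Each |a - b| of two u8 values is at most 255, so a u16 accumulator can
// absorb 65535 / 255 = 257 element differences before it can possibly wrap.
constexpr unsigned MaxDiff = 255;
constexpr unsigned MaxElems = UINT16_MAX / MaxDiff;   // 257
constexpr unsigned BlockElems = 16;                   // hypothetical <16 x i8> blocks
constexpr unsigned MaxBlocks = MaxElems / BlockElems;

// Sixteen 16-element SAD blocks fit in a u16 chain before switching to i32.
static_assert(MaxBlocks == 16, "u16 chain can cover 16 blocks");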
Comment: (This is mostly a response to myself.)
Here's the alive2 proof for the transformation in this patch:
https://alive2.llvm.org/ce/z/XoCBZ5
Note the need for noundef on the source parameters. Alternatively, we could use freeze in the target.
Here's my proposed variant:
https://alive2.llvm.org/ce/z/f7MdJe
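For the freeze alternative mentioned above, a hypothetical adjustment to the combine (not in the patch as posted) would freeze both inputs before forming the min/max, so an undef lane cannot give umax and umin inconsistent views of the same value:

// Sketch only; Op0, Op1, and DL are the locals from performABSCombine above.
Op0 = DAG.getFreeze(Op0);
Op1 = DAG.getFreeze(Op1);
SDValue Max = DAG.getNode(ISD::UMAX, DL, Op0.getValueType(), Op0, Op1);
SDValue Min = DAG.getNode(ISD::UMIN, DL, Op0.getValueType(), Op0, Op1);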
Comment: Is lowering (abs (sub X, Y)) to (sub (umax X, Y), (umin X, Y)) worthwhile doing on its own, ignoring pulling through the zexts for now?

Reply: I'm not sure that's valid.

Comment: I was thinking that we would check that the sub doesn't overflow with computeKnownBits, e.g. that the upper bits of X and Y are zero: https://alive2.llvm.org/ce/z/MZuw8V
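A sketch of what that guard might look like in the combine (hypothetical, not part of this patch; names follow performABSCombine above, with Src as the ISD::SUB node):

// Only rewrite (abs (sub X, Y)) -> (sub (umax X, Y), (umin X, Y)) when the
// sign bit of both inputs is known zero: then X - Y cannot wrap and
// abs(X - Y) == umax(X, Y) - umin(X, Y).
SDValue X = Src.getOperand(0), Y = Src.getOperand(1);
if (!DAG.computeKnownBits(X).isNonNegative() ||
    !DAG.computeKnownBits(Y).isNonNegative())
  return SDValue();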