[RISCV] Undo unprofitable zext of icmp combine #134306

Conversation
@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Changes

InstCombine will combine this zext of an icmp where the source has a single bit set to a lshr plus trunc (InstCombinerImpl::transformZExtICmp):

define <vscale x 1 x i8> @f(<vscale x 1 x i64> %x) {
  %1 = and <vscale x 1 x i64> %x, splat (i64 8)
  %2 = icmp ne <vscale x 1 x i64> %1, splat (i64 0)
  %3 = zext <vscale x 1 x i1> %2 to <vscale x 1 x i8>
  ret <vscale x 1 x i8> %3
}

define <vscale x 1 x i8> @f(<vscale x 1 x i64> %x) #0 {
  %1 = and <vscale x 1 x i64> %x, splat (i64 8)
  %.lobit = lshr exact <vscale x 1 x i64> %1, splat (i64 3)
  %2 = trunc nuw nsw <vscale x 1 x i64> %.lobit to <vscale x 1 x i8>
  ret <vscale x 1 x i8> %2
}

In a loop, this ends up being unprofitable for RISC-V because the codegen now goes from:

f:                                      # @f
    .cfi_startproc
# %bb.0:
    vsetvli a0, zero, e64, m1, ta, ma
    vand.vi v8, v8, 8
    vmsne.vi v0, v8, 0
    vsetvli zero, zero, e8, mf8, ta, ma
    vmv.v.i v8, 0
    vmerge.vim v8, v8, 1, v0
    ret

To a series of narrowing vnsrl.wis:

f:                                      # @f
    .cfi_startproc
# %bb.0:
    vsetvli a0, zero, e64, m1, ta, ma
    vand.vi v8, v8, 8
    vsetvli zero, zero, e32, mf2, ta, ma
    vnsrl.wi v8, v8, 3
    vsetvli zero, zero, e16, mf4, ta, ma
    vnsrl.wi v8, v8, 0
    vsetvli zero, zero, e8, mf8, ta, ma
    vnsrl.wi v8, v8, 0
    ret

In the original form, the vmv.v.i is loop invariant and is hoisted out, and the vmerge.vim usually gets folded away into a masked instruction, so you usually just end up with a vsetvli + vmsne.vi. The truncate requires multiple instructions and introduces a vtype toggle for each one, and is measurably slower on the BPI-F3.

This reverses the transform in RISCVCodeGenPrepare for truncations greater than twice the bitwidth, i.e. it keeps single vnsrl.wis.

Fixes #132245

Full diff: https://github.com/llvm/llvm-project/pull/134306.diff

3 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp b/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
index b5cb05f30fb26..e04f3b1d3478e 100644
--- a/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
+++ b/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
@@ -25,6 +25,7 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/Local.h"
using namespace llvm;
@@ -62,10 +63,74 @@ class RISCVCodeGenPrepare : public FunctionPass,
} // end anonymous namespace
-// Try to optimize (i64 (and (zext/sext (i32 X), C1))) if C1 has bit 31 set,
-// but bits 63:32 are zero. If we know that bit 31 of X is 0, we can fill
-// the upper 32 bits with ones.
+// InstCombinerImpl::transformZExtICmp will narrow a zext of an icmp with a
+// truncation. But RVV doesn't have truncation instructions for more than twice
+// the bitwidth.
+//
+// E.g. trunc <vscale x 1 x i64> %x to <vscale x 1 x i8> will generate:
+//
+// vsetvli a0, zero, e32, m2, ta, ma
+// vnsrl.wi v12, v8, 0
+// vsetvli zero, zero, e16, m1, ta, ma
+// vnsrl.wi v8, v12, 0
+// vsetvli zero, zero, e8, mf2, ta, ma
+// vnsrl.wi v8, v8, 0
+//
+// So reverse the combine so we generate a vmseq/vmsne again:
+//
+// and (lshr (trunc X), ShAmt), 1
+// -->
+// zext (icmp ne (and X, (1 << ShAmt)), 0)
+//
+// and (lshr (not (trunc X)), ShAmt), 1
+// -->
+// zext (icmp eq (and X, (1 << ShAmt)), 0)
+static bool reverseZExtICmpCombine(BinaryOperator &BO) {
+ using namespace PatternMatch;
+
+ assert(BO.getOpcode() == BinaryOperator::And);
+
+ if (!BO.getType()->isVectorTy())
+ return false;
+ const APInt *ShAmt;
+ Value *Inner;
+ if (!match(&BO,
+ m_And(m_OneUse(m_LShr(m_OneUse(m_Value(Inner)), m_APInt(ShAmt))),
+ m_One())))
+ return false;
+
+ Value *X;
+ bool IsNot;
+ if (match(Inner, m_Not(m_Trunc(m_Value(X)))))
+ IsNot = true;
+ else if (match(Inner, m_Trunc(m_Value(X))))
+ IsNot = false;
+ else
+ return false;
+
+ if (BO.getType()->getScalarSizeInBits() >=
+ X->getType()->getScalarSizeInBits() / 2)
+ return false;
+
+ IRBuilder<> Builder(&BO);
+ Value *Res = Builder.CreateAnd(
+ X, ConstantInt::get(X->getType(), 1 << ShAmt->getZExtValue()));
+ Res = Builder.CreateICmp(IsNot ? CmpInst::Predicate::ICMP_EQ
+ : CmpInst::Predicate::ICMP_NE,
+ Res, ConstantInt::get(X->getType(), 0));
+ Res = Builder.CreateZExt(Res, BO.getType());
+ BO.replaceAllUsesWith(Res);
+ RecursivelyDeleteTriviallyDeadInstructions(&BO);
+ return true;
+}
+
bool RISCVCodeGenPrepare::visitAnd(BinaryOperator &BO) {
+ if (reverseZExtICmpCombine(BO))
+ return true;
+
+ // Try to optimize (i64 (and (zext/sext (i32 X), C1))) if C1 has bit 31 set,
+ // but bits 63:32 are zero. If we know that bit 31 of X is 0, we can fill
+ // the upper 32 bits with ones.
if (!ST->is64Bit())
return false;
diff --git a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
index 4e5f6e0f65489..b6593eac6d92c 100644
--- a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare-asm.ll
@@ -498,3 +498,84 @@ vector.body: ; preds = %vector.body, %entry
for.cond.cleanup: ; preds = %vector.body
ret float %red
}
+
+define <vscale x 1 x i8> @reverse_zexticmp_i16(<vscale x 1 x i16> %x) {
+; CHECK-LABEL: reverse_zexticmp_i16:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli a0, zero, e8, mf8, ta, ma
+; CHECK-NEXT: vnsrl.wi v8, v8, 0
+; CHECK-NEXT: vsrl.vi v8, v8, 2
+; CHECK-NEXT: vand.vi v8, v8, 1
+; CHECK-NEXT: ret
+ %1 = trunc <vscale x 1 x i16> %x to <vscale x 1 x i8>
+ %2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
+ %3 = and <vscale x 1 x i8> %2, splat (i8 1)
+ ret <vscale x 1 x i8> %3
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_i32(<vscale x 1 x i32> %x) {
+; CHECK-LABEL: reverse_zexticmp_i32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vand.vi v8, v8, 4
+; CHECK-NEXT: vmsne.vi v0, v8, 0
+; CHECK-NEXT: vsetvli zero, zero, e8, mf8, ta, ma
+; CHECK-NEXT: vmv.v.i v8, 0
+; CHECK-NEXT: vmerge.vim v8, v8, 1, v0
+; CHECK-NEXT: ret
+ %1 = trunc <vscale x 1 x i32> %x to <vscale x 1 x i8>
+ %2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
+ %3 = and <vscale x 1 x i8> %2, splat (i8 1)
+ ret <vscale x 1 x i8> %3
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_neg_i32(<vscale x 1 x i32> %x) {
+; CHECK-LABEL: reverse_zexticmp_neg_i32:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli a0, zero, e32, mf2, ta, ma
+; CHECK-NEXT: vand.vi v8, v8, 4
+; CHECK-NEXT: vmseq.vi v0, v8, 0
+; CHECK-NEXT: vsetvli zero, zero, e8, mf8, ta, ma
+; CHECK-NEXT: vmv.v.i v8, 0
+; CHECK-NEXT: vmerge.vim v8, v8, 1, v0
+; CHECK-NEXT: ret
+ %1 = trunc <vscale x 1 x i32> %x to <vscale x 1 x i8>
+ %2 = xor <vscale x 1 x i8> %1, splat (i8 -1)
+ %3 = lshr <vscale x 1 x i8> %2, splat (i8 2)
+ %4 = and <vscale x 1 x i8> %3, splat (i8 1)
+ ret <vscale x 1 x i8> %4
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_i64(<vscale x 1 x i64> %x) {
+; CHECK-LABEL: reverse_zexticmp_i64:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT: vand.vi v8, v8, 4
+; CHECK-NEXT: vmsne.vi v0, v8, 0
+; CHECK-NEXT: vsetvli zero, zero, e8, mf8, ta, ma
+; CHECK-NEXT: vmv.v.i v8, 0
+; CHECK-NEXT: vmerge.vim v8, v8, 1, v0
+; CHECK-NEXT: ret
+ %1 = trunc <vscale x 1 x i64> %x to <vscale x 1 x i8>
+ %2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
+ %3 = and <vscale x 1 x i8> %2, splat (i8 1)
+ ret <vscale x 1 x i8> %3
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_neg_i64(<vscale x 1 x i64> %x) {
+; CHECK-LABEL: reverse_zexticmp_neg_i64:
+; CHECK: # %bb.0:
+; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT: vand.vi v8, v8, 4
+; CHECK-NEXT: vmseq.vi v0, v8, 0
+; CHECK-NEXT: vsetvli zero, zero, e8, mf8, ta, ma
+; CHECK-NEXT: vmv.v.i v8, 0
+; CHECK-NEXT: vmerge.vim v8, v8, 1, v0
+; CHECK-NEXT: ret
+ %1 = trunc <vscale x 1 x i64> %x to <vscale x 1 x i8>
+ %2 = xor <vscale x 1 x i8> %1, splat (i8 -1)
+ %3 = lshr <vscale x 1 x i8> %2, splat (i8 2)
+ %4 = and <vscale x 1 x i8> %3, splat (i8 1)
+ ret <vscale x 1 x i8> %4
+}
+
diff --git a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
index 8967fb8bf01ac..483e797151325 100644
--- a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
@@ -528,3 +528,75 @@ vector.body: ; preds = %vector.body, %entry
for.cond.cleanup: ; preds = %vector.body
ret float %red
}
+
+define <vscale x 1 x i8> @reverse_zexticmp_i16(<vscale x 1 x i16> %x) {
+; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_i16(
+; CHECK-SAME: <vscale x 1 x i16> [[X:%.*]]) #[[ATTR2]] {
+; CHECK-NEXT: [[TMP1:%.*]] = trunc <vscale x 1 x i16> [[X]] to <vscale x 1 x i8>
+; CHECK-NEXT: [[TMP2:%.*]] = lshr <vscale x 1 x i8> [[TMP1]], splat (i8 2)
+; CHECK-NEXT: [[TMP3:%.*]] = and <vscale x 1 x i8> [[TMP2]], splat (i8 1)
+; CHECK-NEXT: ret <vscale x 1 x i8> [[TMP3]]
+;
+ %1 = trunc <vscale x 1 x i16> %x to <vscale x 1 x i8>
+ %2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
+ %3 = and <vscale x 1 x i8> %2, splat (i8 1)
+ ret <vscale x 1 x i8> %3
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_i32(<vscale x 1 x i32> %x) {
+; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_i32(
+; CHECK-SAME: <vscale x 1 x i32> [[X:%.*]]) #[[ATTR2]] {
+; CHECK-NEXT: [[TMP1:%.*]] = and <vscale x 1 x i32> [[X]], splat (i32 4)
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <vscale x 1 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = zext <vscale x 1 x i1> [[TMP2]] to <vscale x 1 x i8>
+; CHECK-NEXT: ret <vscale x 1 x i8> [[TMP3]]
+;
+ %1 = trunc <vscale x 1 x i32> %x to <vscale x 1 x i8>
+ %2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
+ %3 = and <vscale x 1 x i8> %2, splat (i8 1)
+ ret <vscale x 1 x i8> %3
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_neg_i32(<vscale x 1 x i32> %x) {
+; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_neg_i32(
+; CHECK-SAME: <vscale x 1 x i32> [[X:%.*]]) #[[ATTR2]] {
+; CHECK-NEXT: [[TMP1:%.*]] = and <vscale x 1 x i32> [[X]], splat (i32 4)
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq <vscale x 1 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP4:%.*]] = zext <vscale x 1 x i1> [[TMP2]] to <vscale x 1 x i8>
+; CHECK-NEXT: ret <vscale x 1 x i8> [[TMP4]]
+;
+ %1 = trunc <vscale x 1 x i32> %x to <vscale x 1 x i8>
+ %2 = xor <vscale x 1 x i8> %1, splat (i8 -1)
+ %3 = lshr <vscale x 1 x i8> %2, splat (i8 2)
+ %4 = and <vscale x 1 x i8> %3, splat (i8 1)
+ ret <vscale x 1 x i8> %4
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_i64(<vscale x 1 x i64> %x) {
+; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_i64(
+; CHECK-SAME: <vscale x 1 x i64> [[X:%.*]]) #[[ATTR2]] {
+; CHECK-NEXT: [[TMP1:%.*]] = and <vscale x 1 x i64> [[X]], splat (i64 4)
+; CHECK-NEXT: [[TMP2:%.*]] = icmp ne <vscale x 1 x i64> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = zext <vscale x 1 x i1> [[TMP2]] to <vscale x 1 x i8>
+; CHECK-NEXT: ret <vscale x 1 x i8> [[TMP3]]
+;
+ %1 = trunc <vscale x 1 x i64> %x to <vscale x 1 x i8>
+ %2 = lshr <vscale x 1 x i8> %1, splat (i8 2)
+ %3 = and <vscale x 1 x i8> %2, splat (i8 1)
+ ret <vscale x 1 x i8> %3
+}
+
+define <vscale x 1 x i8> @reverse_zexticmp_neg_i64(<vscale x 1 x i64> %x) {
+; CHECK-LABEL: define <vscale x 1 x i8> @reverse_zexticmp_neg_i64(
+; CHECK-SAME: <vscale x 1 x i64> [[X:%.*]]) #[[ATTR2]] {
+; CHECK-NEXT: [[TMP1:%.*]] = and <vscale x 1 x i64> [[X]], splat (i64 4)
+; CHECK-NEXT: [[TMP2:%.*]] = icmp eq <vscale x 1 x i64> [[TMP1]], zeroinitializer
+; CHECK-NEXT: [[TMP4:%.*]] = zext <vscale x 1 x i1> [[TMP2]] to <vscale x 1 x i8>
+; CHECK-NEXT: ret <vscale x 1 x i8> [[TMP4]]
+;
+ %1 = trunc <vscale x 1 x i64> %x to <vscale x 1 x i8>
+ %2 = xor <vscale x 1 x i8> %1, splat (i8 -1)
+ %3 = lshr <vscale x 1 x i8> %2, splat (i8 2)
+ %4 = and <vscale x 1 x i8> %3, splat (i8 1)
+ ret <vscale x 1 x i8> %4
+}
Can this be done in DAGCombine or does it need something that only codegen prepare can do?
LGTM
static bool reverseZExtICmpCombine(BinaryOperator &BO) {
  using namespace PatternMatch;

  assert(BO.getOpcode() == BinaryOperator::And);
I think we could generalize this using demanded bits, but I'm not sure if doing so is actually worthwhile.
Yeah, for what it's worth the case I'm specifically trying to catch in InstCombine seems to only match ands anyway:
llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
Lines 1022 to 1029 in 85fdab3
if (Cmp->isEquality()) {
  // Test if a bit is clear/set using a shifted-one mask:
  // zext (icmp eq (and X, (1 << ShAmt)), 0) --> and (lshr (not X), ShAmt), 1
  // zext (icmp ne (and X, (1 << ShAmt)), 0) --> and (lshr X, ShAmt), 1
  Value *X, *ShAmt;
  if (Cmp->hasOneUse() && match(Cmp->getOperand(1), m_ZeroInt()) &&
      match(Cmp->getOperand(0),
            m_OneUse(m_c_And(m_Shl(m_One(), m_Value(ShAmt)), m_Value(X))))) {
It can be done in DAGCombine, I can move it over if that's preferred.
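For reference, a minimal sketch of what a DAGCombine version might look like, assuming it is wired into the RISC-V AND combine in RISCVISelLowering. The function name, hook placement, and subtarget checks here are illustrative rather than the PR's final code, and the negated (not) variant is omitted:

static SDValue reverseZExtSetCCCombine(SDNode *N, SelectionDAG &DAG) {
  // Match (and (srl (trunc X), ShAmt), 1) on vectors where the trunc would
  // need more than one vnsrl.wi, and rebuild it as
  // (zext (setne (and X, 1 << ShAmt), 0)).
  SDLoc DL(N);
  EVT VT = N->getValueType(0);
  if (!VT.isVector() || !isOneOrOneSplat(N->getOperand(1)))
    return SDValue();

  SDValue Srl = N->getOperand(0);
  if (Srl.getOpcode() != ISD::SRL || !Srl.hasOneUse())
    return SDValue();
  SDValue Trunc = Srl.getOperand(0);
  if (Trunc.getOpcode() != ISD::TRUNCATE || !Trunc.hasOneUse())
    return SDValue();

  APInt ShAmt;
  if (!ISD::isConstantSplatVector(Srl.getOperand(1).getNode(), ShAmt))
    return SDValue();

  SDValue X = Trunc.getOperand(0);
  EVT WideVT = X.getValueType();
  // Only bother when more than one narrowing shift would be needed.
  if (VT.getScalarSizeInBits() >= WideVT.getScalarSizeInBits() / 2)
    return SDValue();

  SDValue Mask = DAG.getConstant(
      APInt::getOneBitSet(WideVT.getScalarSizeInBits(), ShAmt.getZExtValue()),
      DL, WideVT);
  SDValue And = DAG.getNode(ISD::AND, DL, WideVT, X, Mask);
  EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
                              WideVT.getVectorElementCount());
  SDValue Cmp = DAG.getSetCC(DL, CCVT, And, DAG.getConstant(0, DL, WideVT),
                             ISD::SETNE);
  return DAG.getNode(ISD::ZERO_EXTEND, DL, VT, Cmp);
}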
// -->
// zext (icmp eq (and X, (1 << ShAmt)), 0)
static bool reverseZExtICmpCombine(BinaryOperator &BO) {
  using namespace PatternMatch;
Should we be checking that the vector extensions are enabled?
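If so, a guard along these lines would probably be enough (a sketch; ST is the RISCVSubtarget pointer the pass already queries for is64Bit()):

// Hypothetical guard in RISCVCodeGenPrepare::visitAnd so the vector-only
// reverse combine is only tried when RVV is available.
if (ST->hasVInstructions() && reverseZExtICmpCombine(BO))
  return true;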
LGTM
I suggest doing it in DAGCombine.
SDValue Res =
    DAG.getNode(ISD::AND, DL, WideVT, X,
                DAG.getConstant(1 << ShAmt.getZExtValue(), DL, WideVT));
Res = DAG.getSetCC(DL, WideVT.changeElementType(MVT::i1), Res,
Don't use changeElementType. It's half broken. If WideVT happens to be a simple VT, but the VT with the element type changed is not simple, it will fail.
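A sketch of the usual workaround, building the i1 vector type explicitly instead (variable names follow the quoted snippet above):

// Build the setcc result type from scratch rather than via changeElementType.
EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1,
                            WideVT.getVectorElementCount());
Res = DAG.getSetCC(DL, CCVT, Res, DAG.getConstant(0, DL, WideVT), ISD::SETNE);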
LGTM
InstCombine will combine this zext of an icmp where the source has a single bit set to a lshr plus trunc (InstCombinerImpl::transformZExtICmp).

In a loop, this ends up being unprofitable for RISC-V because the codegen goes from a vsetvli + vmsne.vi to a series of narrowing vnsrl.wis (see the codegen comparison in the first comment above).

In the original form, the vmv.v.i is loop invariant and is hoisted out, and the vmerge.vim usually gets folded away into a masked instruction, so you usually just end up with a vsetvli + vmsne.vi.

The truncate requires multiple instructions and introduces a vtype toggle for each one, and is measurably slower on the BPI-F3.

This reverses the transform in RISCVISelLowering for truncations greater than twice the bitwidth, i.e. it keeps single vnsrl.wis.

Fixes #132245