Skip to content

Commit 0ad275c

Browse files
authored
[InstCombine] Fold vector.reduce.op(vector.reverse(X)) -> vector.reduce.op(X) (#91743)
For all of the following reductions: vector.reduce.or, vector.reduce.and, vector.reduce.xor, vector.reduce.add, vector.reduce.mul, vector.reduce.umin, vector.reduce.umax, vector.reduce.smin, vector.reduce.smax, vector.reduce.fmin, and vector.reduce.fmax — if the input operand is the result of a vector.reverse, then we can perform the reduction on the vector.reverse's input instead, since the answer is the same. If reassociation is permitted, we can also apply the same fold to these: vector.reduce.fadd and vector.reduce.fmul.
1 parent 98deeda commit 0ad275c

File tree

3 files changed

+308
-20
lines changed

3 files changed

+308
-20
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 74 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,6 +1435,33 @@ static Instruction *foldBitOrderCrossLogicOp(Value *V,
14351435
return nullptr;
14361436
}
14371437

1438+
/// If lane reordering is permitted, look through lane-permuting operations on
/// the reduction operand \p Arg and return the underlying unpermuted vector.
///
/// Two shapes are recognized:
///   * vector.reverse(X)                    -> X
///   * shufflevector(X, undef, perm-mask)   -> X   (fixed-width vectors only;
///     the mask must be a single-source permutation touching every lane
///     exactly once, with no poison elements)
///
/// \param Arg              the operand being reduced.
/// \param CanReorderLanes  true when the reduction's result is insensitive to
///                         lane order (e.g. integer reductions, or FP
///                         reductions with reassociation allowed).
/// \returns the source vector when the permutation can be dropped, or
///          nullptr when no fold applies.
static Value *simplifyReductionOperand(Value *Arg, bool CanReorderLanes) {
  if (!CanReorderLanes)
    return nullptr;

  // A vector.reverse only permutes lanes, so a lane-order-insensitive
  // reduction of the reversed vector equals the reduction of its input.
  Value *Src;
  if (match(Arg, m_VecReverse(m_Value(Src))))
    return Src;

  // Otherwise look for a single-source shuffle of a fixed-width vector.
  ArrayRef<int> Mask;
  if (!isa<FixedVectorType>(Arg->getType()) ||
      !match(Arg, m_Shuffle(m_Value(Src), m_Undef(), m_Mask(Mask))) ||
      !cast<ShuffleVectorInst>(Arg)->isSingleSource())
    return nullptr;

  // The shuffle is removable only when it is a pure permutation: every
  // source lane referenced exactly once, no repeats, no poison elements.
  int NumElts = Mask.size();
  SmallBitVector Seen(NumElts);
  for (int MaskElt : Mask) {
    if (MaskElt == PoisonMaskElem || Seen.test(MaskElt))
      return nullptr;
    Seen.set(MaskElt);
  }

  return Seen.all() ? Src : nullptr;
}
1464+
14381465
/// CallInst simplification. This mostly only handles folding of intrinsic
14391466
/// instructions. For normal calls, it allows visitCallBase to do the heavy
14401467
/// lifting.
@@ -3223,6 +3250,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
32233250
// %res = cmp eq iReduxWidth %val, 11111
32243251
Value *Arg = II->getArgOperand(0);
32253252
Value *Vect;
3253+
3254+
if (Value *NewOp =
3255+
simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) {
3256+
replaceUse(II->getOperandUse(0), NewOp);
3257+
return II;
3258+
}
3259+
32263260
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
32273261
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
32283262
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3254,6 +3288,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
32543288
// Trunc(ctpop(bitcast <n x i1> to in)).
32553289
Value *Arg = II->getArgOperand(0);
32563290
Value *Vect;
3291+
3292+
if (Value *NewOp =
3293+
simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) {
3294+
replaceUse(II->getOperandUse(0), NewOp);
3295+
return II;
3296+
}
3297+
32573298
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
32583299
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
32593300
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3282,6 +3323,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
32823323
// ?ext(vector_reduce_add(<n x i1>))
32833324
Value *Arg = II->getArgOperand(0);
32843325
Value *Vect;
3326+
3327+
if (Value *NewOp =
3328+
simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) {
3329+
replaceUse(II->getOperandUse(0), NewOp);
3330+
return II;
3331+
}
3332+
32853333
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
32863334
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
32873335
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3305,6 +3353,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
33053353
// zext(vector_reduce_and(<n x i1>))
33063354
Value *Arg = II->getArgOperand(0);
33073355
Value *Vect;
3356+
3357+
if (Value *NewOp =
3358+
simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) {
3359+
replaceUse(II->getOperandUse(0), NewOp);
3360+
return II;
3361+
}
3362+
33083363
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
33093364
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
33103365
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3329,6 +3384,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
33293384
// ?ext(vector_reduce_{and,or}(<n x i1>))
33303385
Value *Arg = II->getArgOperand(0);
33313386
Value *Vect;
3387+
3388+
if (Value *NewOp =
3389+
simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) {
3390+
replaceUse(II->getOperandUse(0), NewOp);
3391+
return II;
3392+
}
3393+
33323394
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
33333395
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
33343396
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3364,6 +3426,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
33643426
// zext(vector_reduce_{and,or}(<n x i1>))
33653427
Value *Arg = II->getArgOperand(0);
33663428
Value *Vect;
3429+
3430+
if (Value *NewOp =
3431+
simplifyReductionOperand(Arg, /*CanReorderLanes=*/true)) {
3432+
replaceUse(II->getOperandUse(0), NewOp);
3433+
return II;
3434+
}
3435+
33673436
if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) {
33683437
if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType()))
33693438
if (FTy->getElementType() == Builder.getInt1Ty()) {
@@ -3386,31 +3455,16 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
33863455
case Intrinsic::vector_reduce_fmin:
33873456
case Intrinsic::vector_reduce_fadd:
33883457
case Intrinsic::vector_reduce_fmul: {
3389-
bool CanBeReassociated = (IID != Intrinsic::vector_reduce_fadd &&
3390-
IID != Intrinsic::vector_reduce_fmul) ||
3391-
II->hasAllowReassoc();
3458+
bool CanReorderLanes = (IID != Intrinsic::vector_reduce_fadd &&
3459+
IID != Intrinsic::vector_reduce_fmul) ||
3460+
II->hasAllowReassoc();
33923461
const unsigned ArgIdx = (IID == Intrinsic::vector_reduce_fadd ||
33933462
IID == Intrinsic::vector_reduce_fmul)
33943463
? 1
33953464
: 0;
33963465
Value *Arg = II->getArgOperand(ArgIdx);
3397-
Value *V;
3398-
ArrayRef<int> Mask;
3399-
if (!isa<FixedVectorType>(Arg->getType()) || !CanBeReassociated ||
3400-
!match(Arg, m_Shuffle(m_Value(V), m_Undef(), m_Mask(Mask))) ||
3401-
!cast<ShuffleVectorInst>(Arg)->isSingleSource())
3402-
break;
3403-
int Sz = Mask.size();
3404-
SmallBitVector UsedIndices(Sz);
3405-
for (int Idx : Mask) {
3406-
if (Idx == PoisonMaskElem || UsedIndices.test(Idx))
3407-
break;
3408-
UsedIndices.set(Idx);
3409-
}
3410-
// Can remove shuffle iff just shuffled elements, no repeats, undefs, or
3411-
// other changes.
3412-
if (UsedIndices.all()) {
3413-
replaceUse(II->getOperandUse(ArgIdx), V);
3466+
if (Value *NewOp = simplifyReductionOperand(Arg, CanReorderLanes)) {
3467+
replaceUse(II->getOperandUse(ArgIdx), NewOp);
34143468
return nullptr;
34153469
}
34163470
break;

llvm/test/Transforms/InstCombine/vector-logical-reductions.ll

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,77 @@ define i1 @reduction_logical_and(<4 x i1> %x) {
2121
ret i1 %r
2222
}
2323

24+
define i1 @reduction_logical_or_reverse_nxv2i1(<vscale x 2 x i1> %p) {
25+
; CHECK-LABEL: @reduction_logical_or_reverse_nxv2i1(
26+
; CHECK-NEXT: [[RED:%.*]] = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
27+
; CHECK-NEXT: ret i1 [[RED]]
28+
;
29+
%rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %p)
30+
%red = call i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1> %rev)
31+
ret i1 %red
32+
}
33+
34+
define i1 @reduction_logical_or_reverse_v2i1(<2 x i1> %p) {
35+
; CHECK-LABEL: @reduction_logical_or_reverse_v2i1(
36+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i1> [[P:%.*]] to i2
37+
; CHECK-NEXT: [[RED:%.*]] = icmp ne i2 [[TMP1]], 0
38+
; CHECK-NEXT: ret i1 [[RED]]
39+
;
40+
%rev = call <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1> %p)
41+
%red = call i1 @llvm.vector.reduce.or.v2i1(<2 x i1> %rev)
42+
ret i1 %red
43+
}
44+
45+
define i1 @reduction_logical_and_reverse_nxv2i1(<vscale x 2 x i1> %p) {
46+
; CHECK-LABEL: @reduction_logical_and_reverse_nxv2i1(
47+
; CHECK-NEXT: [[RED:%.*]] = call i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
48+
; CHECK-NEXT: ret i1 [[RED]]
49+
;
50+
%rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %p)
51+
%red = call i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1> %rev)
52+
ret i1 %red
53+
}
54+
55+
define i1 @reduction_logical_and_reverse_v2i1(<2 x i1> %p) {
56+
; CHECK-LABEL: @reduction_logical_and_reverse_v2i1(
57+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i1> [[P:%.*]] to i2
58+
; CHECK-NEXT: [[RED:%.*]] = icmp eq i2 [[TMP1]], -1
59+
; CHECK-NEXT: ret i1 [[RED]]
60+
;
61+
%rev = call <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1> %p)
62+
%red = call i1 @llvm.vector.reduce.and.v2i1(<2 x i1> %rev)
63+
ret i1 %red
64+
}
65+
66+
define i1 @reduction_logical_xor_reverse_nxv2i1(<vscale x 2 x i1> %p) {
67+
; CHECK-LABEL: @reduction_logical_xor_reverse_nxv2i1(
68+
; CHECK-NEXT: [[RED:%.*]] = call i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1> [[P:%.*]])
69+
; CHECK-NEXT: ret i1 [[RED]]
70+
;
71+
%rev = call <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1> %p)
72+
%red = call i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1> %rev)
73+
ret i1 %red
74+
}
75+
76+
define i1 @reduction_logical_xor_reverse_v2i1(<2 x i1> %p) {
77+
; CHECK-LABEL: @reduction_logical_xor_reverse_v2i1(
78+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i1> [[P:%.*]] to i2
79+
; CHECK-NEXT: [[TMP2:%.*]] = call range(i2 0, -1) i2 @llvm.ctpop.i2(i2 [[TMP1]])
80+
; CHECK-NEXT: [[RED:%.*]] = trunc i2 [[TMP2]] to i1
81+
; CHECK-NEXT: ret i1 [[RED]]
82+
;
83+
%rev = call <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1> %p)
84+
%red = call i1 @llvm.vector.reduce.xor.v2i1(<2 x i1> %rev)
85+
ret i1 %red
86+
}
87+
2488
declare i1 @llvm.vector.reduce.or.v4i1(<4 x i1>)
89+
declare i1 @llvm.vector.reduce.or.nxv2i1(<vscale x 2 x i1>)
90+
declare i1 @llvm.vector.reduce.or.v2i1(<2 x i1>)
2591
declare i1 @llvm.vector.reduce.and.v4i1(<4 x i1>)
92+
declare i1 @llvm.vector.reduce.and.nxv2i1(<vscale x 2 x i1>)
93+
declare i1 @llvm.vector.reduce.and.v2i1(<2 x i1>)
94+
declare i1 @llvm.vector.reduce.xor.nxv2i1(<vscale x 2 x i1>)
95+
declare i1 @llvm.vector.reduce.xor.v2i1(<2 x i1>)
96+
declare <vscale x 2 x i1> @llvm.vector.reverse.nxv2i1(<vscale x 2 x i1>)
97+
declare <2 x i1> @llvm.vector.reverse.v2i1(<2 x i1>)

0 commit comments

Comments
 (0)