Skip to content

Commit b4b0c02

Browse files
authored
[SLP][REVEC] Make tryToReduce and related functions support vector instructions. (#102327)
1 parent a27f40e commit b4b0c02

File tree

2 files changed

+125
-5
lines changed

2 files changed

+125
-5
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17946,8 +17946,37 @@ class HorizontalReduction {
1794617946
SameValuesCounter, TrackedToOrig);
1794717947
}
1794817948

17949-
Value *ReducedSubTree =
17950-
emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI);
17949+
Value *ReducedSubTree;
17950+
Type *ScalarTy = VL.front()->getType();
17951+
if (isa<FixedVectorType>(ScalarTy)) {
17952+
assert(SLPReVec && "FixedVectorType is not expected.");
17953+
unsigned ScalarTyNumElements = getNumElements(ScalarTy);
17954+
ReducedSubTree = PoisonValue::get(FixedVectorType::get(
17955+
VectorizedRoot->getType()->getScalarType(), ScalarTyNumElements));
17956+
for (unsigned I : seq<unsigned>(ScalarTyNumElements)) {
17957+
// Do reduction for each lane.
17958+
// e.g., do reduce add for
17959+
// VL[0] = <4 x Ty> <a, b, c, d>
17960+
// VL[1] = <4 x Ty> <e, f, g, h>
17961+
// Lane[0] = <2 x Ty> <a, e>
17962+
// Lane[1] = <2 x Ty> <b, f>
17963+
// Lane[2] = <2 x Ty> <c, g>
17964+
// Lane[3] = <2 x Ty> <d, h>
17965+
// result[0] = reduce add Lane[0]
17966+
// result[1] = reduce add Lane[1]
17967+
// result[2] = reduce add Lane[2]
17968+
// result[3] = reduce add Lane[3]
17969+
SmallVector<int, 16> Mask =
17970+
createStrideMask(I, ScalarTyNumElements, VL.size());
17971+
Value *Lane = Builder.CreateShuffleVector(VectorizedRoot, Mask);
17972+
ReducedSubTree = Builder.CreateInsertElement(
17973+
ReducedSubTree, emitReduction(Lane, Builder, ReduxWidth, TTI),
17974+
I);
17975+
}
17976+
} else {
17977+
ReducedSubTree =
17978+
emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI);
17979+
}
1795117980
if (ReducedSubTree->getType() != VL.front()->getType()) {
1795217981
assert(ReducedSubTree->getType() != VL.front()->getType() &&
1795317982
"Expected different reduction type.");
@@ -18175,9 +18204,25 @@ class HorizontalReduction {
1817518204
case RecurKind::FAdd:
1817618205
case RecurKind::FMul: {
1817718206
unsigned RdxOpcode = RecurrenceDescriptor::getOpcode(RdxKind);
18178-
if (!AllConsts)
18179-
VectorCost =
18180-
TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF, CostKind);
18207+
if (!AllConsts) {
18208+
if (auto *VecTy = dyn_cast<FixedVectorType>(ScalarTy)) {
18209+
assert(SLPReVec && "FixedVectorType is not expected.");
18210+
unsigned ScalarTyNumElements = VecTy->getNumElements();
18211+
for (unsigned I : seq<unsigned>(ReducedVals.size())) {
18212+
VectorCost += TTI->getShuffleCost(
18213+
TTI::SK_PermuteSingleSrc, VectorTy,
18214+
createStrideMask(I, ScalarTyNumElements, ReducedVals.size()));
18215+
VectorCost += TTI->getArithmeticReductionCost(RdxOpcode, VecTy, FMF,
18216+
CostKind);
18217+
}
18218+
VectorCost += TTI->getScalarizationOverhead(
18219+
VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
18220+
/*Extract*/ false, TTI::TCK_RecipThroughput);
18221+
} else {
18222+
VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy, FMF,
18223+
CostKind);
18224+
}
18225+
}
1818118226
ScalarCost = EvaluateScalarCost([&]() {
1818218227
return TTI->getArithmeticInstrCost(RdxOpcode, ScalarTy, CostKind);
1818318228
});

llvm/test/Transforms/SLPVectorizer/revec.ll

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,78 @@ entry:
147147
%5 = icmp ult <4 x ptr> %3, %4
148148
ret void
149149
}
150+
151+
; REVEC reduction test (SLPVectorizer with -slp-revec): the reduction roots
; here are <4 x i1> values, not scalars, exercising the vector-instruction
; support in tryToReduce. The CHECK lines record the expected lowering:
; the combined <16 x i1> value is split per lane with strided shufflevector
; masks (<0,4,8,12>, <1,5,9,13>, <2,6,10,14>, <3,7,11,15>), each group is
; reduced with @llvm.vector.reduce.or.v4i1, and the four i1 results are
; reassembled into the final <4 x i1> via insertelement.
define <4 x i1> @test6(ptr %in1, ptr %in2) {
; CHECK-LABEL: @test6(
; CHECK-NEXT: entry:
; CHECK-NEXT: [[TMP0:%.*]] = load <4 x i32>, ptr [[IN1:%.*]], align 4
; CHECK-NEXT: [[TMP1:%.*]] = load <4 x i16>, ptr [[IN2:%.*]], align 2
; CHECK-NEXT: [[TMP2:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> poison, <4 x i32> poison, i64 4)
; CHECK-NEXT: [[TMP3:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> [[TMP2]], <4 x i32> poison, i64 8)
; CHECK-NEXT: [[TMP4:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> [[TMP3]], <4 x i32> poison, i64 12)
; CHECK-NEXT: [[TMP5:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> [[TMP4]], <4 x i32> [[TMP0]], i64 0)
; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <16 x i32> [[TMP5]], <16 x i32> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP7:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> poison, <4 x i32> zeroinitializer, i64 0)
; CHECK-NEXT: [[TMP8:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> [[TMP7]], <4 x i32> zeroinitializer, i64 4)
; CHECK-NEXT: [[TMP9:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> [[TMP8]], <4 x i32> zeroinitializer, i64 8)
; CHECK-NEXT: [[TMP10:%.*]] = call <16 x i32> @llvm.vector.insert.v16i32.v4i32(<16 x i32> [[TMP9]], <4 x i32> zeroinitializer, i64 12)
; CHECK-NEXT: [[TMP11:%.*]] = icmp ugt <16 x i32> [[TMP6]], [[TMP10]]
; CHECK-NEXT: [[TMP12:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> poison, <4 x i16> poison, i64 4)
; CHECK-NEXT: [[TMP13:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> [[TMP12]], <4 x i16> poison, i64 8)
; CHECK-NEXT: [[TMP14:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> [[TMP13]], <4 x i16> poison, i64 12)
; CHECK-NEXT: [[TMP15:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> [[TMP14]], <4 x i16> [[TMP1]], i64 0)
; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <16 x i16> [[TMP15]], <16 x i16> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3, i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[TMP17:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> poison, <4 x i16> zeroinitializer, i64 0)
; CHECK-NEXT: [[TMP18:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> [[TMP17]], <4 x i16> zeroinitializer, i64 4)
; CHECK-NEXT: [[TMP19:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> [[TMP18]], <4 x i16> zeroinitializer, i64 8)
; CHECK-NEXT: [[TMP20:%.*]] = call <16 x i16> @llvm.vector.insert.v16i16.v4i16(<16 x i16> [[TMP19]], <4 x i16> zeroinitializer, i64 12)
; CHECK-NEXT: [[TMP21:%.*]] = icmp eq <16 x i16> [[TMP16]], [[TMP20]]
; CHECK-NEXT: [[TMP22:%.*]] = and <16 x i1> [[TMP11]], [[TMP21]]
; CHECK-NEXT: [[TMP23:%.*]] = icmp ugt <16 x i32> [[TMP6]], [[TMP10]]
; CHECK-NEXT: [[TMP24:%.*]] = and <16 x i1> [[TMP22]], [[TMP23]]
; CHECK-NEXT: [[TMP25:%.*]] = shufflevector <16 x i1> [[TMP24]], <16 x i1> poison, <4 x i32> <i32 0, i32 4, i32 8, i32 12>
; CHECK-NEXT: [[TMP26:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP25]])
; CHECK-NEXT: [[TMP27:%.*]] = insertelement <4 x i1> poison, i1 [[TMP26]], i64 0
; CHECK-NEXT: [[TMP28:%.*]] = shufflevector <16 x i1> [[TMP24]], <16 x i1> poison, <4 x i32> <i32 1, i32 5, i32 9, i32 13>
; CHECK-NEXT: [[TMP29:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP28]])
; CHECK-NEXT: [[TMP30:%.*]] = insertelement <4 x i1> [[TMP27]], i1 [[TMP29]], i64 1
; CHECK-NEXT: [[TMP31:%.*]] = shufflevector <16 x i1> [[TMP24]], <16 x i1> poison, <4 x i32> <i32 2, i32 6, i32 10, i32 14>
; CHECK-NEXT: [[TMP32:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP31]])
; CHECK-NEXT: [[TMP33:%.*]] = insertelement <4 x i1> [[TMP30]], i1 [[TMP32]], i64 2
; CHECK-NEXT: [[TMP34:%.*]] = shufflevector <16 x i1> [[TMP24]], <16 x i1> poison, <4 x i32> <i32 3, i32 7, i32 11, i32 15>
; CHECK-NEXT: [[TMP35:%.*]] = call i1 @llvm.vector.reduce.or.v4i1(<4 x i1> [[TMP34]])
; CHECK-NEXT: [[TMP36:%.*]] = insertelement <4 x i1> [[TMP33]], i1 [[TMP35]], i64 3
; CHECK-NEXT: [[VBSL:%.*]] = select <4 x i1> [[TMP36]], <4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt <4 x i32> [[VBSL]], <i32 2, i32 3, i32 4, i32 5>
; CHECK-NEXT: ret <4 x i1> [[CMP]]
;
entry:
  %0 = load <4 x i32>, ptr %in1, align 4
  %1 = load <4 x i16>, ptr %in2, align 2
  ; Four identical ugt compares of the i32 load and four identical eq
  ; compares of the i16 load: vectorizable <4 x i1> values for the reduction.
  %cmp000 = icmp ugt <4 x i32> %0, zeroinitializer
  %cmp001 = icmp ugt <4 x i32> %0, zeroinitializer
  %cmp002 = icmp ugt <4 x i32> %0, zeroinitializer
  %cmp003 = icmp ugt <4 x i32> %0, zeroinitializer
  %cmp100 = icmp eq <4 x i16> %1, zeroinitializer
  %cmp101 = icmp eq <4 x i16> %1, zeroinitializer
  %cmp102 = icmp eq <4 x i16> %1, zeroinitializer
  %cmp103 = icmp eq <4 x i16> %1, zeroinitializer
  %and.cmp0 = and <4 x i1> %cmp000, %cmp100
  %and.cmp1 = and <4 x i1> %cmp001, %cmp101
  %and.cmp2 = and <4 x i1> %cmp002, %cmp102
  %and.cmp3 = and <4 x i1> %cmp003, %cmp103
  %cmp004 = icmp ugt <4 x i32> %0, zeroinitializer
  %cmp005 = icmp ugt <4 x i32> %0, zeroinitializer
  %cmp006 = icmp ugt <4 x i32> %0, zeroinitializer
  %cmp007 = icmp ugt <4 x i32> %0, zeroinitializer
  %and.cmp4 = and <4 x i1> %and.cmp0, %cmp004
  %and.cmp5 = and <4 x i1> %and.cmp1, %cmp005
  %and.cmp6 = and <4 x i1> %and.cmp2, %cmp006
  %and.cmp7 = and <4 x i1> %and.cmp3, %cmp007
  ; An `or` reduction tree over the four <4 x i1> values %and.cmp4..%and.cmp7;
  ; this is the horizontal reduction REVEC rewrites into the per-lane
  ; strided-shuffle + vector.reduce.or form checked above.
  %or0 = or <4 x i1> %and.cmp5, %and.cmp4
  %or1 = or <4 x i1> %or0, %and.cmp6
  %or2 = or <4 x i1> %or1, %and.cmp7
  %vbsl = select <4 x i1> %or2, <4 x i32> <i32 1, i32 2, i32 3, i32 4>, <4 x i32> <i32 5, i32 6, i32 7, i32 8>
  %cmp = icmp ugt <4 x i32> %vbsl, <i32 2, i32 3, i32 4, i32 5>
  ret <4 x i1> %cmp
}

0 commit comments

Comments
 (0)