Commit d9cc5d8
[AArch64][SVE] Combine bitcasts of predicate types with vector inserts/extracts of loads/stores
An insert subvector that inserts the result of a predicate-vector-sized load into undef at index 0, and whose result is cast to a predicate type, can be combined into a direct predicate load. The same applies to extract subvector, but in reverse.

The purpose of this optimization is to clean up cases that will be introduced in a later patch, where casts to/from predicate types via i8 types will use insert/extract subvector rather than going through memory early.

This optimization is done in SVEIntrinsicOpts rather than InstCombine so that scalable loads and stores are reintroduced as late as possible, giving other optimizations the best chance to do a good job.

Differential Revision: https://reviews.llvm.org/D106549
1 parent 478c71b commit d9cc5d8
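To make the intent of the combine concrete, here is a minimal before/after sketch of the load direction, distilled from the @pred_load_v2i8 test added below (vscale_range(1,1), so a <vscale x 16 x i1> predicate is exactly 2 bytes). The value names in the "after" form are illustrative, since the test's CHECK lines only match patterns:

; Before: fixed-width load, inserted into undef at index 0, then bitcast to a predicate type.
  %load   = load <2 x i8>, <2 x i8>* %addr, align 4
  %insert = tail call <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v2i8(<vscale x 2 x i8> undef, <2 x i8> %load, i64 0)
  %ret    = bitcast <vscale x 2 x i8> %insert to <vscale x 16 x i1>

; After: a single load of the predicate type through a bitcast pointer.
  %1 = bitcast <2 x i8>* %addr to <vscale x 16 x i1>*
  %2 = load <vscale x 16 x i1>, <vscale x 16 x i1>* %1

The store direction is the mirror image: an extract from a bitcast of a predicate, stored to memory, becomes a direct <vscale x 16 x i1> store. Both combines only fire when the vscale_range attribute pins vscale to a single non-zero value, because only then is the fixed-width vector known to cover exactly one predicate register.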

3 files changed: +361 additions, -0 deletions

llvm/lib/Target/AArch64/SVEIntrinsicOpts.cpp

Lines changed: 161 additions & 0 deletions
@@ -59,6 +59,10 @@ struct SVEIntrinsicOpts : public ModulePass {
   bool coalescePTrueIntrinsicCalls(BasicBlock &BB,
                                    SmallSetVector<IntrinsicInst *, 4> &PTrues);
   bool optimizePTrueIntrinsicCalls(SmallSetVector<Function *, 4> &Functions);
+  bool optimizePredicateStore(Instruction *I);
+  bool optimizePredicateLoad(Instruction *I);
+
+  bool optimizeInstructions(SmallSetVector<Function *, 4> &Functions);

   /// Operates at the function-scope. I.e., optimizations are applied local to
   /// the functions themselves.
@@ -276,11 +280,166 @@ bool SVEIntrinsicOpts::optimizePTrueIntrinsicCalls(
   return Changed;
 }

+// This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
+// scalable stores as late as possible
+bool SVEIntrinsicOpts::optimizePredicateStore(Instruction *I) {
+  auto *F = I->getFunction();
+  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
+  if (!Attr.isValid())
+    return false;
+
+  unsigned MinVScale, MaxVScale;
+  std::tie(MinVScale, MaxVScale) = Attr.getVScaleRangeArgs();
+  // The transform needs to know the exact runtime length of scalable vectors
+  if (MinVScale != MaxVScale || MinVScale == 0)
+    return false;
+
+  auto *PredType =
+      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
+  auto *FixedPredType =
+      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);
+
+  // If we have a store..
+  auto *Store = dyn_cast<StoreInst>(I);
+  if (!Store || !Store->isSimple())
+    return false;
+
+  // ..that is storing a predicate vector sized worth of bits..
+  if (Store->getOperand(0)->getType() != FixedPredType)
+    return false;
+
+  // ..where the value stored comes from a vector extract..
+  auto *IntrI = dyn_cast<IntrinsicInst>(Store->getOperand(0));
+  if (!IntrI ||
+      IntrI->getIntrinsicID() != Intrinsic::experimental_vector_extract)
+    return false;
+
+  // ..that is extracting from index 0..
+  if (!cast<ConstantInt>(IntrI->getOperand(1))->isZero())
+    return false;
+
+  // ..where the value being extract from comes from a bitcast
+  auto *BitCast = dyn_cast<BitCastInst>(IntrI->getOperand(0));
+  if (!BitCast)
+    return false;
+
+  // ..and the bitcast is casting from predicate type
+  if (BitCast->getOperand(0)->getType() != PredType)
+    return false;
+
+  IRBuilder<> Builder(I->getContext());
+  Builder.SetInsertPoint(I);
+
+  auto *PtrBitCast = Builder.CreateBitCast(
+      Store->getPointerOperand(),
+      PredType->getPointerTo(Store->getPointerAddressSpace()));
+  Builder.CreateStore(BitCast->getOperand(0), PtrBitCast);
+
+  Store->eraseFromParent();
+  if (IntrI->getNumUses() == 0)
+    IntrI->eraseFromParent();
+  if (BitCast->getNumUses() == 0)
+    BitCast->eraseFromParent();
+
+  return true;
+}
+
+// This is done in SVEIntrinsicOpts rather than InstCombine so that we introduce
+// scalable loads as late as possible
+bool SVEIntrinsicOpts::optimizePredicateLoad(Instruction *I) {
+  auto *F = I->getFunction();
+  auto Attr = F->getFnAttribute(Attribute::VScaleRange);
+  if (!Attr.isValid())
+    return false;
+
+  unsigned MinVScale, MaxVScale;
+  std::tie(MinVScale, MaxVScale) = Attr.getVScaleRangeArgs();
+  // The transform needs to know the exact runtime length of scalable vectors
+  if (MinVScale != MaxVScale || MinVScale == 0)
+    return false;
+
+  auto *PredType =
+      ScalableVectorType::get(Type::getInt1Ty(I->getContext()), 16);
+  auto *FixedPredType =
+      FixedVectorType::get(Type::getInt8Ty(I->getContext()), MinVScale * 2);
+
+  // If we have a bitcast..
+  auto *BitCast = dyn_cast<BitCastInst>(I);
+  if (!BitCast || BitCast->getType() != PredType)
+    return false;
+
+  // ..whose operand is a vector_insert..
+  auto *IntrI = dyn_cast<IntrinsicInst>(BitCast->getOperand(0));
+  if (!IntrI ||
+      IntrI->getIntrinsicID() != Intrinsic::experimental_vector_insert)
+    return false;
+
+  // ..that is inserting into index zero of an undef vector..
+  if (!isa<UndefValue>(IntrI->getOperand(0)) ||
+      !cast<ConstantInt>(IntrI->getOperand(2))->isZero())
+    return false;
+
+  // ..where the value inserted comes from a load..
+  auto *Load = dyn_cast<LoadInst>(IntrI->getOperand(1));
+  if (!Load || !Load->isSimple())
+    return false;
+
+  // ..that is loading a predicate vector sized worth of bits..
+  if (Load->getType() != FixedPredType)
+    return false;
+
+  IRBuilder<> Builder(I->getContext());
+  Builder.SetInsertPoint(Load);
+
+  auto *PtrBitCast = Builder.CreateBitCast(
+      Load->getPointerOperand(),
+      PredType->getPointerTo(Load->getPointerAddressSpace()));
+  auto *LoadPred = Builder.CreateLoad(PredType, PtrBitCast);
+
+  BitCast->replaceAllUsesWith(LoadPred);
+  BitCast->eraseFromParent();
+  if (IntrI->getNumUses() == 0)
+    IntrI->eraseFromParent();
+  if (Load->getNumUses() == 0)
+    Load->eraseFromParent();
+
+  return true;
+}
+
+bool SVEIntrinsicOpts::optimizeInstructions(
+    SmallSetVector<Function *, 4> &Functions) {
+  bool Changed = false;
+
+  for (auto *F : Functions) {
+    DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>(*F).getDomTree();
+
+    // Traverse the DT with an rpo walk so we see defs before uses, allowing
+    // simplification to be done incrementally.
+    BasicBlock *Root = DT->getRoot();
+    ReversePostOrderTraversal<BasicBlock *> RPOT(Root);
+    for (auto *BB : RPOT) {
+      for (Instruction &I : make_early_inc_range(*BB)) {
+        switch (I.getOpcode()) {
+        case Instruction::Store:
+          Changed |= optimizePredicateStore(&I);
+          break;
+        case Instruction::BitCast:
+          Changed |= optimizePredicateLoad(&I);
+          break;
+        }
+      }
+    }
+  }
+
+  return Changed;
+}
+
 bool SVEIntrinsicOpts::optimizeFunctions(
     SmallSetVector<Function *, 4> &Functions) {
   bool Changed = false;

   Changed |= optimizePTrueIntrinsicCalls(Functions);
+  Changed |= optimizeInstructions(Functions);

   return Changed;
 }
@@ -297,6 +456,8 @@ bool SVEIntrinsicOpts::runOnModule(Module &M) {
       continue;

     switch (F.getIntrinsicID()) {
+    case Intrinsic::experimental_vector_extract:
+    case Intrinsic::experimental_vector_insert:
     case Intrinsic::aarch64_sve_ptrue:
       for (User *U : F.users())
         Functions.insert(cast<Instruction>(U)->getFunction());
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@ (new test file)
; RUN: opt -S -aarch64-sve-intrinsic-opts < %s | FileCheck %s

target triple = "aarch64-unknown-linux-gnu"

define void @pred_store_v2i8(<vscale x 16 x i1> %pred, <2 x i8>* %addr) #0 {
; CHECK-LABEL: @pred_store_v2i8(
; CHECK-NEXT:  [[TMP1:%.*]] = bitcast <2 x i8>* %addr to <vscale x 16 x i1>*
; CHECK-NEXT:  store <vscale x 16 x i1> %pred, <vscale x 16 x i1>* [[TMP1]]
; CHECK-NEXT:  ret void
  %bitcast = bitcast <vscale x 16 x i1> %pred to <vscale x 2 x i8>
  %extract = tail call <2 x i8> @llvm.experimental.vector.extract.v2i8.nxv2i8(<vscale x 2 x i8> %bitcast, i64 0)
  store <2 x i8> %extract, <2 x i8>* %addr, align 4
  ret void
}

define void @pred_store_v4i8(<vscale x 16 x i1> %pred, <4 x i8>* %addr) #1 {
; CHECK-LABEL: @pred_store_v4i8(
; CHECK-NEXT:  [[TMP1:%.*]] = bitcast <4 x i8>* %addr to <vscale x 16 x i1>*
; CHECK-NEXT:  store <vscale x 16 x i1> %pred, <vscale x 16 x i1>* [[TMP1]]
; CHECK-NEXT:  ret void
  %bitcast = bitcast <vscale x 16 x i1> %pred to <vscale x 2 x i8>
  %extract = tail call <4 x i8> @llvm.experimental.vector.extract.v4i8.nxv2i8(<vscale x 2 x i8> %bitcast, i64 0)
  store <4 x i8> %extract, <4 x i8>* %addr, align 4
  ret void
}

define void @pred_store_v8i8(<vscale x 16 x i1> %pred, <8 x i8>* %addr) #2 {
; CHECK-LABEL: @pred_store_v8i8(
; CHECK-NEXT:  [[TMP1:%.*]] = bitcast <8 x i8>* %addr to <vscale x 16 x i1>*
; CHECK-NEXT:  store <vscale x 16 x i1> %pred, <vscale x 16 x i1>* [[TMP1]]
; CHECK-NEXT:  ret void
  %bitcast = bitcast <vscale x 16 x i1> %pred to <vscale x 2 x i8>
  %extract = tail call <8 x i8> @llvm.experimental.vector.extract.v8i8.nxv2i8(<vscale x 2 x i8> %bitcast, i64 0)
  store <8 x i8> %extract, <8 x i8>* %addr, align 4
  ret void
}


; Check that too small of a vscale prevents optimization
define void @pred_store_neg1(<vscale x 16 x i1> %pred, <4 x i8>* %addr) #0 {
; CHECK-LABEL: @pred_store_neg1(
; CHECK: call <4 x i8> @llvm.experimental.vector.extract
  %bitcast = bitcast <vscale x 16 x i1> %pred to <vscale x 2 x i8>
  %extract = tail call <4 x i8> @llvm.experimental.vector.extract.v4i8.nxv2i8(<vscale x 2 x i8> %bitcast, i64 0)
  store <4 x i8> %extract, <4 x i8>* %addr, align 4
  ret void
}

; Check that too large of a vscale prevents optimization
define void @pred_store_neg2(<vscale x 16 x i1> %pred, <4 x i8>* %addr) #2 {
; CHECK-LABEL: @pred_store_neg2(
; CHECK: call <4 x i8> @llvm.experimental.vector.extract
  %bitcast = bitcast <vscale x 16 x i1> %pred to <vscale x 2 x i8>
  %extract = tail call <4 x i8> @llvm.experimental.vector.extract.v4i8.nxv2i8(<vscale x 2 x i8> %bitcast, i64 0)
  store <4 x i8> %extract, <4 x i8>* %addr, align 4
  ret void
}

; Check that a non-zero index prevents optimization
define void @pred_store_neg3(<vscale x 16 x i1> %pred, <4 x i8>* %addr) #1 {
; CHECK-LABEL: @pred_store_neg3(
; CHECK: call <4 x i8> @llvm.experimental.vector.extract
  %bitcast = bitcast <vscale x 16 x i1> %pred to <vscale x 2 x i8>
  %extract = tail call <4 x i8> @llvm.experimental.vector.extract.v4i8.nxv2i8(<vscale x 2 x i8> %bitcast, i64 4)
  store <4 x i8> %extract, <4 x i8>* %addr, align 4
  ret void
}

; Check that differing vscale min/max prevents optimization
define void @pred_store_neg4(<vscale x 16 x i1> %pred, <4 x i8>* %addr) #3 {
; CHECK-LABEL: @pred_store_neg4(
; CHECK: call <4 x i8> @llvm.experimental.vector.extract
  %bitcast = bitcast <vscale x 16 x i1> %pred to <vscale x 2 x i8>
  %extract = tail call <4 x i8> @llvm.experimental.vector.extract.v4i8.nxv2i8(<vscale x 2 x i8> %bitcast, i64 0)
  store <4 x i8> %extract, <4 x i8>* %addr, align 4
  ret void
}

declare <2 x i8> @llvm.experimental.vector.extract.v2i8.nxv2i8(<vscale x 2 x i8>, i64)
declare <4 x i8> @llvm.experimental.vector.extract.v4i8.nxv2i8(<vscale x 2 x i8>, i64)
declare <8 x i8> @llvm.experimental.vector.extract.v8i8.nxv2i8(<vscale x 2 x i8>, i64)

attributes #0 = { "target-features"="+sve" vscale_range(1,1) }
attributes #1 = { "target-features"="+sve" vscale_range(2,2) }
attributes #2 = { "target-features"="+sve" vscale_range(4,4) }
attributes #3 = { "target-features"="+sve" vscale_range(2,4) }
Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@ (new test file)
; RUN: opt -S -aarch64-sve-intrinsic-opts < %s | FileCheck %s

target triple = "aarch64-unknown-linux-gnu"

define <vscale x 16 x i1> @pred_load_v2i8(<2 x i8>* %addr) #0 {
; CHECK-LABEL: @pred_load_v2i8(
; CHECK-NEXT:  [[TMP1:%.*]] = bitcast <2 x i8>* %addr to <vscale x 16 x i1>*
; CHECK-NEXT:  [[TMP2:%.*]] = load <vscale x 16 x i1>, <vscale x 16 x i1>* [[TMP1]]
; CHECK-NEXT:  ret <vscale x 16 x i1> [[TMP2]]
  %load = load <2 x i8>, <2 x i8>* %addr, align 4
  %insert = tail call <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v2i8(<vscale x 2 x i8> undef, <2 x i8> %load, i64 0)
  %ret = bitcast <vscale x 2 x i8> %insert to <vscale x 16 x i1>
  ret <vscale x 16 x i1> %ret
}

define <vscale x 16 x i1> @pred_load_v4i8(<4 x i8>* %addr) #1 {
; CHECK-LABEL: @pred_load_v4i8(
; CHECK-NEXT:  [[TMP1:%.*]] = bitcast <4 x i8>* %addr to <vscale x 16 x i1>*
; CHECK-NEXT:  [[TMP2:%.*]] = load <vscale x 16 x i1>, <vscale x 16 x i1>* [[TMP1]]
; CHECK-NEXT:  ret <vscale x 16 x i1> [[TMP2]]
  %load = load <4 x i8>, <4 x i8>* %addr, align 4
  %insert = tail call <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v4i8(<vscale x 2 x i8> undef, <4 x i8> %load, i64 0)
  %ret = bitcast <vscale x 2 x i8> %insert to <vscale x 16 x i1>
  ret <vscale x 16 x i1> %ret
}

define <vscale x 16 x i1> @pred_load_v8i8(<8 x i8>* %addr) #2 {
; CHECK-LABEL: @pred_load_v8i8(
; CHECK-NEXT:  [[TMP1:%.*]] = bitcast <8 x i8>* %addr to <vscale x 16 x i1>*
; CHECK-NEXT:  [[TMP2:%.*]] = load <vscale x 16 x i1>, <vscale x 16 x i1>* [[TMP1]]
; CHECK-NEXT:  ret <vscale x 16 x i1> [[TMP2]]
  %load = load <8 x i8>, <8 x i8>* %addr, align 4
  %insert = tail call <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v8i8(<vscale x 2 x i8> undef, <8 x i8> %load, i64 0)
  %ret = bitcast <vscale x 2 x i8> %insert to <vscale x 16 x i1>
  ret <vscale x 16 x i1> %ret
}

; Ensure the insertion point is at the load
define <vscale x 16 x i1> @pred_load_insertion_point(<2 x i8>* %addr) #0 {
; CHECK-LABEL: @pred_load_insertion_point(
; CHECK-NEXT:  entry:
; CHECK-NEXT:  [[TMP1:%.*]] = bitcast <2 x i8>* %addr to <vscale x 16 x i1>*
; CHECK-NEXT:  [[TMP2:%.*]] = load <vscale x 16 x i1>, <vscale x 16 x i1>* [[TMP1]]
; CHECK-NEXT:  br label %bb1
; CHECK:       bb1:
; CHECK-NEXT:  ret <vscale x 16 x i1> [[TMP2]]
entry:
  %load = load <2 x i8>, <2 x i8>* %addr, align 4
  br label %bb1

bb1:
  %insert = tail call <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v2i8(<vscale x 2 x i8> undef, <2 x i8> %load, i64 0)
  %ret = bitcast <vscale x 2 x i8> %insert to <vscale x 16 x i1>
  ret <vscale x 16 x i1> %ret
}

; Check that too small of a vscale prevents optimization
define <vscale x 16 x i1> @pred_load_neg1(<4 x i8>* %addr) #0 {
; CHECK-LABEL: @pred_load_neg1(
; CHECK: call <vscale x 2 x i8> @llvm.experimental.vector.insert
  %load = load <4 x i8>, <4 x i8>* %addr, align 4
  %insert = tail call <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v4i8(<vscale x 2 x i8> undef, <4 x i8> %load, i64 0)
  %ret = bitcast <vscale x 2 x i8> %insert to <vscale x 16 x i1>
  ret <vscale x 16 x i1> %ret
}

; Check that too large of a vscale prevents optimization
define <vscale x 16 x i1> @pred_load_neg2(<4 x i8>* %addr) #2 {
; CHECK-LABEL: @pred_load_neg2(
; CHECK: call <vscale x 2 x i8> @llvm.experimental.vector.insert
  %load = load <4 x i8>, <4 x i8>* %addr, align 4
  %insert = tail call <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v4i8(<vscale x 2 x i8> undef, <4 x i8> %load, i64 0)
  %ret = bitcast <vscale x 2 x i8> %insert to <vscale x 16 x i1>
  ret <vscale x 16 x i1> %ret
}

; Check that a non-zero index prevents optimization
define <vscale x 16 x i1> @pred_load_neg3(<4 x i8>* %addr) #1 {
; CHECK-LABEL: @pred_load_neg3(
; CHECK: call <vscale x 2 x i8> @llvm.experimental.vector.insert
  %load = load <4 x i8>, <4 x i8>* %addr, align 4
  %insert = tail call <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v4i8(<vscale x 2 x i8> undef, <4 x i8> %load, i64 4)
  %ret = bitcast <vscale x 2 x i8> %insert to <vscale x 16 x i1>
  ret <vscale x 16 x i1> %ret
}

; Check that differing vscale min/max prevents optimization
define <vscale x 16 x i1> @pred_load_neg4(<4 x i8>* %addr) #3 {
; CHECK-LABEL: @pred_load_neg4(
; CHECK: call <vscale x 2 x i8> @llvm.experimental.vector.insert
  %load = load <4 x i8>, <4 x i8>* %addr, align 4
  %insert = tail call <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v4i8(<vscale x 2 x i8> undef, <4 x i8> %load, i64 0)
  %ret = bitcast <vscale x 2 x i8> %insert to <vscale x 16 x i1>
  ret <vscale x 16 x i1> %ret
}

; Check that insertion into a non-undef vector prevents optimization
define <vscale x 16 x i1> @pred_load_neg5(<4 x i8>* %addr, <vscale x 2 x i8> %passthru) #1 {
; CHECK-LABEL: @pred_load_neg5(
; CHECK: call <vscale x 2 x i8> @llvm.experimental.vector.insert
  %load = load <4 x i8>, <4 x i8>* %addr, align 4
  %insert = tail call <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v4i8(<vscale x 2 x i8> %passthru, <4 x i8> %load, i64 0)
  %ret = bitcast <vscale x 2 x i8> %insert to <vscale x 16 x i1>
  ret <vscale x 16 x i1> %ret
}

declare <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v2i8(<vscale x 2 x i8>, <2 x i8>, i64)
declare <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v4i8(<vscale x 2 x i8>, <4 x i8>, i64)
declare <vscale x 2 x i8> @llvm.experimental.vector.insert.nxv2i8.v8i8(<vscale x 2 x i8>, <8 x i8>, i64)

attributes #0 = { "target-features"="+sve" vscale_range(1,1) }
attributes #1 = { "target-features"="+sve" vscale_range(2,2) }
attributes #2 = { "target-features"="+sve" vscale_range(4,4) }
attributes #3 = { "target-features"="+sve" vscale_range(2,4) }
