Skip to content

Commit cea940e

Browse files
authored
[SYCL-MLIR]: Teach LICM to hoist more invariant aliased load/stores operations (#8399)
This PR enhances LICM as follows: - allow aliased store operation that dominates load operation to be hoisted - ensure store operation that doesn't dominate the aliased load operation is not hoisted --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 8775b9f commit cea940e

File tree

2 files changed

+126
-42
lines changed

2 files changed

+126
-42
lines changed

polygeist/lib/Dialect/Polygeist/Transforms/LICM.cpp

Lines changed: 63 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@ class OperationSideEffects {
3939
const OperationSideEffects &);
4040

4141
public:
42-
OperationSideEffects(const Operation &op, const AliasAnalysis &aliasAnalysis)
43-
: op(op), aliasAnalysis(aliasAnalysis) {
42+
OperationSideEffects(const Operation &op, const AliasAnalysis &aliasAnalysis,
43+
const DominanceInfo &domInfo)
44+
: op(op), aliasAnalysis(aliasAnalysis), domInfo(domInfo) {
4445
if (auto memEffect = dyn_cast<MemoryEffectOpInterface>(op)) {
4546
SmallVector<MemoryEffects::EffectInstance, 1> effects;
4647
memEffect.getEffects(effects);
@@ -103,6 +104,8 @@ class OperationSideEffects {
103104
private:
104105
const Operation &op; /// Operation associated with the side effects.
105106
const AliasAnalysis &aliasAnalysis; /// Alias Analysis reference.
107+
const DominanceInfo &domInfo; /// Dominance information reference.
108+
106109
/// Side effects associated with reading resources.
107110
SmallVector<MemoryEffects::EffectInstance> readResources;
108111
/// Side effects associated with writing resources.
@@ -293,7 +296,7 @@ bool OperationSideEffects::conflictsWith(const Operation &other) const {
293296
// If the given operation has side effects, check whether they conflict with
294297
// the side effects summarized in this class.
295298
if (auto MEI = dyn_cast<MemoryEffectOpInterface>(other)) {
296-
OperationSideEffects sideEffects(other, aliasAnalysis);
299+
OperationSideEffects sideEffects(other, aliasAnalysis, domInfo);
297300

298301
// Checks for a conflicts on the given resource 'res' by applying the
299302
// supplied predicate function 'hasConflict'.
@@ -313,15 +316,15 @@ bool OperationSideEffects::conflictsWith(const Operation &other) const {
313316
[[maybe_unused]] auto printConflictingSideEffects =
314317
[](const MemoryEffects::EffectInstance &EI, AliasResult aliasRes,
315318
const Operation &other) {
316-
llvm::dbgs() << "Found conflicting side effect: {"
317-
<< EI.getResource()->getName() << ", " << EI.getValue()
318-
<< "}\n";
319-
llvm::dbgs().indent(2) << "with: " << other << "\n";
320-
llvm::dbgs().indent(2) << "aliasResult: " << aliasRes << "\n";
319+
llvm::dbgs().indent(2)
320+
<< "found conflicting side effect: {"
321+
<< EI.getResource()->getName() << ", " << EI.getValue() << "}\n";
322+
llvm::dbgs().indent(4) << "with: " << other << "\n";
323+
llvm::dbgs().indent(4) << "aliasResult: " << aliasRes << "\n";
321324
};
322325

323-
// Check whether the given operation 'other' writes (or allocates, or frees)
324-
// a resource that is read by the operation associated with this class.
326+
// Check whether the given operation 'other' allocates, writes, or frees a
327+
// resource that is read by the operation associated with this class.
325328
if (llvm::any_of(
326329
readResources, [&](const MemoryEffects::EffectInstance &readRes) {
327330
auto hasConflict = [&](const MemoryEffects::EffectInstance &EI) {
@@ -346,21 +349,35 @@ bool OperationSideEffects::conflictsWith(const Operation &other) const {
346349
// Check whether the given operation 'other' allocates, reads, writes or
347350
// frees a resource that is written by the operation associated with this
348351
// class.
349-
if (llvm::any_of(
350-
writeResources, [&](const MemoryEffects::EffectInstance &writeRes) {
351-
auto hasConflict = [&](const MemoryEffects::EffectInstance &EI) {
352-
AliasResult aliasRes =
353-
const_cast<AliasAnalysis &>(aliasAnalysis)
354-
.alias(EI.getValue(), writeRes.getValue());
355-
if (aliasRes.isNo())
356-
return false;
352+
if (llvm::any_of(writeResources, [&](const MemoryEffects::EffectInstance
353+
&writeRes) {
354+
auto hasConflict = [&](const MemoryEffects::EffectInstance &EI) {
355+
AliasResult aliasRes =
356+
const_cast<AliasAnalysis &>(aliasAnalysis)
357+
.alias(EI.getValue(), writeRes.getValue());
358+
if (aliasRes.isNo())
359+
return false;
360+
361+
// An aliased read operation doesn't prevent hoisting if it is
362+
// dominated by the write operation.
363+
if (isa<MemoryEffects::Read>(EI.getEffect()) &&
364+
domInfo.dominates(const_cast<Operation *>(&op),
365+
const_cast<Operation *>(&other))) {
366+
LLVM_DEBUG({
367+
printConflictingSideEffects(EI, aliasRes, other);
368+
llvm::dbgs().indent(2)
369+
<< "can be hoisted: aliased write operation dominates the "
370+
"read operation\n";
371+
});
372+
return false;
373+
}
357374

358-
LLVM_DEBUG(printConflictingSideEffects(EI, aliasRes, other));
359-
return true;
360-
};
375+
LLVM_DEBUG(printConflictingSideEffects(EI, aliasRes, other));
376+
return true;
377+
};
361378

362-
return checkForConflict(writeRes.getResource(), hasConflict);
363-
})) {
379+
return checkForConflict(writeRes.getResource(), hasConflict);
380+
})) {
364381
return true;
365382
}
366383

@@ -683,8 +700,9 @@ AffineIfOp AffineParallelGuardBuilder::createGuard() const {
683700
/// conflicts in the loop are given in \p willBeMoved.
684701
static bool hasConflictsInLoop(Operation &op, LoopLikeOpInterface loop,
685702
const SmallPtrSetImpl<Operation *> &willBeMoved,
686-
const AliasAnalysis &aliasAnalysis) {
687-
const OperationSideEffects sideEffects(op, aliasAnalysis);
703+
const AliasAnalysis &aliasAnalysis,
704+
const DominanceInfo &domInfo) {
705+
const OperationSideEffects sideEffects(op, aliasAnalysis, domInfo);
688706

689707
Optional<Operation *> conflictingOp =
690708
TypeSwitch<Operation *, Optional<Operation *>>((Operation *)loop)
@@ -704,13 +722,15 @@ static bool hasConflictsInLoop(Operation &op, LoopLikeOpInterface loop,
704722
if (conflictingOp.has_value()) {
705723
if (!willBeMoved.count(*conflictingOp))
706724
return true;
707-
LLVM_DEBUG(llvm::dbgs() << "OK: related operation will be hoisted\n");
725+
LLVM_DEBUG(llvm::dbgs().indent(2)
726+
<< "can be hoisted: conflicting operation will be hoisted\n");
708727
}
709728

710729
// Check whether the parent operation has conflicts on the loop.
711730
if (op.getParentOp() == loop)
712731
return false;
713-
if (hasConflictsInLoop(*op.getParentOp(), loop, willBeMoved, aliasAnalysis))
732+
if (hasConflictsInLoop(*op.getParentOp(), loop, willBeMoved, aliasAnalysis,
733+
domInfo))
714734
return true;
715735

716736
// If the parent operation is not guaranteed to execute its
@@ -736,7 +756,8 @@ static bool hasConflictsInLoop(Operation &op, LoopLikeOpInterface loop,
736756
/// to be loop invariant (and therefore will be moved outside of the loop).
737757
static bool canBeHoisted(Operation &op, LoopLikeOpInterface loop,
738758
const SmallPtrSetImpl<Operation *> &willBeMoved,
739-
const AliasAnalysis &aliasAnalysis) {
759+
const AliasAnalysis &aliasAnalysis,
760+
const DominanceInfo &domInfo) {
740761
// Returns true if the given value can be moved outside of the loop, and
741762
// false otherwise. A value cannot be moved outside of the loop if its
742763
// operands are not defined outside of the loop and cannot themselves be
@@ -783,7 +804,7 @@ static bool canBeHoisted(Operation &op, LoopLikeOpInterface loop,
783804
}
784805

785806
// Do not hoist operations that allocate a resource.
786-
const OperationSideEffects sideEffects(op, aliasAnalysis);
807+
const OperationSideEffects sideEffects(op, aliasAnalysis, domInfo);
787808
if (sideEffects.allocatesResource()) {
788809
LLVM_DEBUG({
789810
llvm::dbgs() << "Operation: " << op << "\n";
@@ -799,8 +820,8 @@ static bool canBeHoisted(Operation &op, LoopLikeOpInterface loop,
799820
// loop prevent hosting it.
800821
if ((sideEffects.readsFromResource() || sideEffects.writesToResource() ||
801822
sideEffects.freesResource()) &&
802-
hasConflictsInLoop(op, loop, willBeMoved, aliasAnalysis)) {
803-
LLVM_DEBUG(llvm::dbgs()
823+
hasConflictsInLoop(op, loop, willBeMoved, aliasAnalysis, domInfo)) {
824+
LLVM_DEBUG(llvm::dbgs().indent(2)
804825
<< "cannot be hoisted: found conflicting operation\n");
805826
return false;
806827
}
@@ -814,31 +835,30 @@ static bool canBeHoisted(Operation &op, LoopLikeOpInterface loop,
814835

815836
for (Region &region : op.getRegions()) {
816837
for (Operation &innerOp : region.getOps()) {
817-
if (!canBeHoisted(innerOp, loop, willBeMoved2, aliasAnalysis))
838+
if (!canBeHoisted(innerOp, loop, willBeMoved2, aliasAnalysis, domInfo))
818839
return false;
819840
willBeMoved2.insert(&innerOp);
820841
}
821842
}
822843

823-
LLVM_DEBUG(llvm::dbgs() << "can be hoisted: no conflicts found\n");
844+
LLVM_DEBUG(llvm::dbgs().indent(2) << "can be hoisted: no conflicts found\n");
824845

825846
return true;
826847
}
827848

828849
// Populate \p opsToMove with operations that can be hoisted out of the given
829850
// loop \p loop.
830-
static void
831-
collectHoistableOperations(LoopLikeOpInterface loop,
832-
const AliasAnalysis &aliasAnalysis,
833-
SmallVectorImpl<Operation *> &opsToMove) {
851+
static void collectHoistableOperations(
852+
LoopLikeOpInterface loop, const AliasAnalysis &aliasAnalysis,
853+
const DominanceInfo &domInfo, SmallVectorImpl<Operation *> &opsToMove) {
834854
// Do not use walk here, as we do not want to go into nested regions and
835855
// hoist operations from there. These regions might have semantics unknown
836856
// to this rewriting. If the nested regions are loops, they will have been
837857
// processed.
838858
SmallPtrSet<Operation *, 8> willBeMoved;
839859
for (Block &block : loop.getLoopBody()) {
840860
for (Operation &op : block.without_terminator()) {
841-
if (!canBeHoisted(op, loop, willBeMoved, aliasAnalysis))
861+
if (!canBeHoisted(op, loop, willBeMoved, aliasAnalysis, domInfo))
842862
continue;
843863
opsToMove.push_back(&op);
844864
willBeMoved.insert(&op);
@@ -847,13 +867,14 @@ collectHoistableOperations(LoopLikeOpInterface loop,
847867
}
848868

849869
static size_t moveLoopInvariantCode(LoopLikeOpInterface loop,
850-
const AliasAnalysis &aliasAnalysis) {
870+
const AliasAnalysis &aliasAnalysis,
871+
const DominanceInfo &domInfo) {
851872
Operation *loopOp = loop;
852873
if (!isa<scf::ForOp, scf::ParallelOp, AffineParallelOp, AffineForOp>(loopOp))
853874
return 0;
854875

855876
SmallVector<Operation *, 8> opsToMove;
856-
collectHoistableOperations(loop, aliasAnalysis, opsToMove);
877+
collectHoistableOperations(loop, aliasAnalysis, domInfo, opsToMove);
857878
if (opsToMove.empty())
858879
return 0;
859880

@@ -870,6 +891,7 @@ static size_t moveLoopInvariantCode(LoopLikeOpInterface loop,
870891

871892
void LICM::runOnOperation() {
872893
AliasAnalysis &aliasAnalysis = getAnalysis<AliasAnalysis>();
894+
DominanceInfo &domInfo = getAnalysis<DominanceInfo>();
873895

874896
[[maybe_unused]] auto getParentFunction = [](LoopLikeOpInterface loop) {
875897
Operation *parentOp = loop;
@@ -905,7 +927,7 @@ void LICM::runOnOperation() {
905927

906928
// Now use this pass to hoist more complex operations.
907929
{
908-
size_t OpHoisted = moveLoopInvariantCode(loop, aliasAnalysis);
930+
size_t OpHoisted = moveLoopInvariantCode(loop, aliasAnalysis, domInfo);
909931
numOpHoisted += OpHoisted;
910932

911933
LLVM_DEBUG({

polygeist/test/polygeist-opt/licm.mlir

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ func.func @affine_for_hoist2(%arg0: memref<?xf32>, %arg1: f32) {
229229
return
230230
}
231231

232-
// COM: Ensure reductions loops guards are correct.
232+
// COM: Ensure reductions loops guards are correct.
233233
func.func @affine_for_hoist3(%arg0: memref<?xi32>, %arg1: i32) -> (i32) {
234234
// CHECK: func.func @affine_for_hoist3(%arg0: memref<?xi32>, %arg1: i32) -> i32 {
235235
// CHECK-NEXT: %alloca = memref.alloca() : memref<1xi32>
@@ -255,6 +255,68 @@ func.func @affine_for_hoist3(%arg0: memref<?xi32>, %arg1: i32) -> (i32) {
255255
}
256256
return %sum : i32
257257
}
258+
259+
// COM: Ensure aliased store dominating load can be hoisted.
260+
func.func @affine_for_hoist4(%arg0: memref<?xi32>) {
261+
// CHECK: func.func @affine_for_hoist4(%arg0: memref<?xi32>) {
262+
// CHECK-NEXT: %alloca = memref.alloca() : memref<1xi32>
263+
// CHECK-NEXT: %c3_i32 = arith.constant 3 : i32
264+
// CHECK-NEXT: affine.if #set1() {
265+
// CHECK-NEXT: affine.store %c3_i32, %alloca[0] : memref<1xi32>
266+
// CHECK-NEXT: %0 = affine.load %alloca[0] : memref<1xi32>
267+
// CHECK-NEXT: affine.for %arg1 = 0 to 10 {
268+
// CHECK-NEXT: %1 = affine.load %arg0[0] : memref<?xi32>
269+
// CHECK-NEXT: %2 = arith.addi %1, %0 : i32
270+
// CHECK-NEXT: affine.store %2, %arg0[0] : memref<?xi32>
271+
// CHECK-NEXT: }
272+
// CHECK-NEXT: }
273+
274+
%alloca = memref.alloca() : memref<1xi32>
275+
%c3 = arith.constant 3 : i32
276+
affine.for %arg1 = 0 to 10 {
277+
// Store can be hoisted because it is the only reaching definition for the first load.
278+
// - the store dominates the aliased load and
279+
// - there is no other aliased store in the loop
280+
affine.store %c3, %alloca[0] : memref<1xi32>
281+
%c3_1 = affine.load %alloca[0] : memref<1xi32>
282+
%arr = affine.load %arg0[0] : memref<?xi32>
283+
%add = arith.addi %arr, %c3_1 : i32
284+
affine.store %add, %arg0[0] : memref<?xi32>
285+
}
286+
return
287+
}
288+
289+
// COM: Ensure aliased store after dominating load cannot be hoisted.
290+
func.func @affine_for_nohoist1(%arg0: memref<?xi32>) {
291+
// CHECK: func.func @affine_for_nohoist1(%arg0: memref<?xi32>) {
292+
// CHECK-NEXT: %alloca = memref.alloca() : memref<1xi32>
293+
// CHECK-DAG: %c3_i32 = arith.constant 3 : i32
294+
// CHECK-DAG: %c4_i32 = arith.constant 4 : i32
295+
// CHECK-NEXT: affine.store %c3_i32, %alloca[0] : memref<1xi32>
296+
// CHECK-NEXT: affine.for %arg1 = 0 to 10 {
297+
// CHECK-NEXT: %0 = affine.load %alloca[0] : memref<1xi32>
298+
// CHECK-NEXT: affine.store %c4_i32, %alloca[0] : memref<1xi32>
299+
// CHECK-NEXT: %1 = affine.load %arg0[0] : memref<?xi32>
300+
// CHECK-NEXT: %2 = arith.addi %1, %0 : i32
301+
// CHECK-NEXT: affine.store %2, %arg0[0] : memref<?xi32>
302+
// CHECK-NEXT: }
303+
304+
%alloca = memref.alloca() : memref<1xi32>
305+
%c3 = arith.constant 3 : i32
306+
%c4 = arith.constant 4 : i32
307+
affine.store %c3, %alloca[0] : memref<1xi32>
308+
affine.for %arg2 = 0 to 10 {
309+
// Cannot hoist the load because the loop has a store that can change the loaded result.
310+
%c3_1 = affine.load %alloca[0] : memref<1xi32>
311+
// Cannot hoist the store because it changes the value loaded by the previous operation,
312+
// (the store does not dominate the load %c3_1).
313+
affine.store %c4, %alloca[0] : memref<1xi32>
314+
%arr = affine.load %arg0[0] : memref<?xi32>
315+
%add = arith.addi %arr, %c3_1 : i32
316+
affine.store %add, %arg0[0] : memref<?xi32>
317+
}
318+
return
319+
}
258320
}
259321

260322
// -----

0 commit comments

Comments
 (0)