Skip to content

Commit eaa32d2

Browse files
authored
[mlir] fix affine-loop-fusion crash (llvm#76351)
If `user` does not lie within the `Region`, `findAncestorOpInRegion` returns `nullptr`, which was then dereferenced. Fixes llvm#76281.
1 parent 6c87f46 commit eaa32d2

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,10 @@ static bool isEscapingMemref(Value memref, Block *block) {
205205
// (e.g., call ops, alias creating ops, etc.).
206206
return llvm::any_of(memref.getUsers(), [&](Operation *user) {
207207
// Ignore users outside of `block`.
208-
if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != block)
208+
Operation *ancestorOp = block->getParent()->findAncestorOpInRegion(*user);
209+
if (!ancestorOp)
210+
return true;
211+
if (ancestorOp->getBlock() != block)
209212
return false;
210213
return !isa<AffineMapAccessInterface>(*user);
211214
});

mlir/test/Dialect/Affine/loop-fusion.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1541,5 +1541,37 @@ func.func @should_fuse_and_preserve_dep_on_constant() {
15411541
return
15421542
}
15431543

1544+
// -----
1545+
1546+
// CHECK-LABEL: @producer_consumer_with_outmost_user
1547+
func.func @producer_consumer_with_outmost_user(%arg0 : f16) {
1548+
%c0 = arith.constant 0 : index
1549+
%src = memref.alloc() : memref<f16, 1>
1550+
%dst = memref.alloc() : memref<f16>
1551+
%tag = memref.alloc() : memref<1xi32>
1552+
affine.for %arg1 = 4 to 6 {
1553+
affine.for %arg2 = 0 to 1 {
1554+
%0 = arith.addf %arg0, %arg0 : f16
1555+
affine.store %0, %src[] : memref<f16, 1>
1556+
}
1557+
affine.for %arg3 = 0 to 1 {
1558+
%0 = affine.load %src[] : memref<f16, 1>
1559+
}
1560+
}
1561+
affine.dma_start %src[], %dst[], %tag[%c0], %c0 : memref<f16, 1>, memref<f16>, memref<1xi32>
1562+
// CHECK: %[[CST_INDEX:.*]] = arith.constant 0 : index
1563+
// CHECK: %[[DMA_SRC:.*]] = memref.alloc() : memref<f16, 1>
1564+
// CHECK: %[[DMA_DST:.*]] = memref.alloc() : memref<f16>
1565+
// CHECK: %[[DMA_TAG:.*]] = memref.alloc() : memref<1xi32>
1566+
// CHECK: affine.for %arg1 = 4 to 6
1567+
// CHECK-NEXT: affine.for %arg2 = 0 to 1
1568+
// CHECK-NEXT: %[[RESULT_ADD:.*]] = arith.addf %arg0, %arg0 : f16
1569+
// CHECK-NEXT: affine.store %[[RESULT_ADD]], %[[DMA_SRC]][] : memref<f16, 1>
1570+
// CHECK-NEXT: affine.load %[[DMA_SRC]][] : memref<f16, 1>
1571+
// CHECK: affine.dma_start %[[DMA_SRC]][], %[[DMA_DST]][], %[[DMA_TAG]][%[[CST_INDEX]]], %[[CST_INDEX]] : memref<f16, 1>, memref<f16>, memref<1xi32>
1572+
// CHECK-NEXT: return
1573+
return
1574+
}
1575+
15441576
// Add further tests in mlir/test/Transforms/loop-fusion-4.mlir
15451577

0 commit comments

Comments
 (0)