Skip to content

Commit c42a262

Browse files
[MLIR] Bug Fix: affine.prefetch replaceAffineOp invoked during canonicalization (#88346)
Signed-off-by: Alexandre Eichenberger <[email protected]>
1 parent 7aa3716 commit c42a262

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,9 +1487,8 @@ void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
14871487
PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
14881488
ArrayRef<Value> mapOperands) const {
14891489
rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
1490-
prefetch, prefetch.getMemref(), map, mapOperands,
1491-
prefetch.getLocalityHint(), prefetch.getIsWrite(),
1492-
prefetch.getIsDataCache());
1490+
prefetch, prefetch.getMemref(), map, mapOperands, prefetch.getIsWrite(),
1491+
prefetch.getLocalityHint(), prefetch.getIsDataCache());
14931492
}
14941493
template <>
14951494
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(

mlir/test/Dialect/Affine/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,3 +1452,17 @@ func.func @mod_of_mod(%lb: index, %ub: index, %step: index) -> (index, index) {
14521452
%1 = affine.apply affine_map<()[s0, s1, s2] -> ((s0 - ((s0 - s2) mod s1) - s2) mod s1)> ()[%ub, %step, %lb]
14531453
return %0, %1 : index, index
14541454
}
1455+
1456+
// -----
1457+
1458+
// CHECK-LABEL: func.func @prefetch_canonicalize
1459+
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<512xf32>) {
1460+
func.func @prefetch_canonicalize(%arg0: memref<512xf32>) -> () {
1461+
// CHECK: affine.for [[I_0_:%.+]] = 0 to 8 {
1462+
affine.for %arg3 = 0 to 8 {
1463+
%1 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%arg3]
1464+
// CHECK: affine.prefetch [[PARAM_0_]][symbol([[I_0_]]) * 64], read, locality<3>, data : memref<512xf32>
1465+
affine.prefetch %arg0[%1], read, locality<3>, data : memref<512xf32>
1466+
}
1467+
return
1468+
}

0 commit comments

Comments
 (0)