Skip to content

Commit 2ed7c3f

Browse files
committed
[MLIR][SCF] Enable better bufferization for TileConsumerAndFuseProducersUsingSCFForOp
Replace the iter operands of the outermost loop with the region arguments of the innermost one. This change prevents later `bufferization` passes from inserting allocations within the body of the innermost loop. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D130083
1 parent 2955192 commit 2ed7c3f

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,23 @@ static Optional<OpResult> getFusableProducer(Value v) {
355355
return v.cast<OpResult>();
356356
}
357357

358+
// Replace uses of the outermost loop's iter operands with the corresponding
359+
// region iter args of the innermost loop. Only uses whose owning operation
// sits directly in the first block of `innerFor`'s region are rewritten; all
// other uses of the outer iter operands are left untouched. Both loops must
// carry the same number of iter operands (checked by the assert below).
// NOTE(review): `rewriter` is not referenced in this body; the uses are
// rewritten directly through `Value::replaceUsesWithIf` - confirm that
// bypassing the rewriter is intended.
360+
static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor,
361+
PatternRewriter &rewriter) {
362+
assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() &&
363+
"expect same number of iter args");
364+
// First block of the innermost loop's region; used below to filter uses.
Block *block = &(*innerFor.getRegion().begin());
365+
// Pair the i-th outer iter operand with the i-th inner region iter arg.
for (auto it :
366+
llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) {
367+
Value source = std::get<0>(it);
368+
Value target = std::get<1>(it);
369+
// Redirect a use only when its owner op's containing block is `block`.
source.replaceUsesWithIf(target, [&](OpOperand &use) {
370+
return use.getOwner()->getBlock() == block;
371+
});
372+
}
373+
}
374+
358375
FailureOr<scf::SCFTileAndFuseResult>
359376
scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
360377
TilingInterface op, PatternRewriter &rewriter) const {
@@ -470,5 +487,7 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
470487
}
471488
}
472489
}
490+
replaceIterArgs(tileAndFuseResult.loops.front(),
491+
tileAndFuseResult.loops.back(), rewriter);
473492
return tileAndFuseResult;
474493
}

mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) ->
2323
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
2424
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
2525
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
26-
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
26+
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
2727
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
2828
// CHECK-SAME: outs(%[[INIT_TILE]] :
2929
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
@@ -68,7 +68,7 @@ func.func @gemm_generic_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
6868
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
6969
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
7070
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
71-
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]]
71+
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV0]], %[[IV1]]]
7272
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
7373
// CHECK-SAME: outs(%[[INIT_TILE]] :
7474
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
@@ -123,7 +123,7 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor<?x?xf32>, %rhs0 : tensor<?x?xf32>, %r
123123
// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
124124
// CHECK-SAME: outs(%[[FILL0_TILE]] :
125125
// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0]
126-
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0]
126+
// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG]][%[[IV]], 0]
127127
// CHECK: %[[FILL1_TILE:.+]] = linalg.fill
128128
// CHECK-SAME: outs(%[[INIT1_TILE]] :
129129
// CHECK: %[[GEMM1_TILE:.+]] = linalg.matmul
@@ -218,7 +218,7 @@ func.func @interchange_matmul_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?
218218
// CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]])
219219
// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
220220
// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
221-
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], %[[IV0]]]
221+
// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]]
222222
// CHECK: %[[FILL_TILE:.+]] = linalg.fill
223223
// CHECK-SAME: outs(%[[INIT_TILE]] :
224224
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul

0 commit comments

Comments
 (0)