Skip to content

Commit 2f2a7b4

Browse files
committed
fix whileProducerOutOfLoopBlock
1 parent bc7f33b commit 2f2a7b4

File tree

2 files changed

+79
-24
lines changed

2 files changed

+79
-24
lines changed

lib/gc/Transforms/TilingUsingInterfaceX.cpp

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,21 @@ struct ErasedOpListener : public RewriterBase::Listener {
278278
bool isErased(Operation *op) { return erased.count(op); }
279279
};
280280

281+
/// Check if it is the ForOp that yield the result of inner loop
282+
static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) {
283+
if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
284+
for (auto &&[index, op] :
285+
llvm::enumerate(forOp.getBody()->getOperations())) {
286+
// If the orderIndex of inner loop is the last second one before the
287+
// yieldOp of ForOp, the given loop must yield the result of inner loop.
288+
if (isa<LoopLikeOpInterface>(op)) {
289+
return success((index + 2) == forOp.getBody()->getOperations().size());
290+
}
291+
}
292+
}
293+
return failure();
294+
}
295+
281296
/// Enhanced version of `tileAndFuseProducerOfSliceImpl`, which can deal with
282297
/// multi-level `extractSliceOp`. E.g.
283298
///
@@ -293,7 +308,9 @@ std::optional<scf::SCFFuseProducerOfSliceResult>
293308
mlir::scfX::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
294309
Operation *candidateSliceOp) {
295310
SmallVector<tensor::ExtractSliceOp> backwardSlice;
296-
if (failed(getRealProducerOfExtractSliceOp(candidateSliceOp, backwardSlice)))
311+
FailureOr<OpResult> realProducer =
312+
getRealProducerOfExtractSliceOp(candidateSliceOp, backwardSlice);
313+
if (failed(realProducer))
297314
return std::nullopt;
298315

299316
std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
@@ -303,14 +320,18 @@ mlir::scfX::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
303320
for (auto &&[index, sliceOp] : llvm::enumerate(backwardSlice)) {
304321
// get nest loops between next candidate sliceOp and tiled producer.
305322
auto whileProducerOutOfLoopBlock =
306-
[&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
307-
if (fuseProducerResult) {
308-
Block &body = loop->getRegion(0).front();
309-
if (fuseProducerResult->tiledAndFusedProducer.getDefiningOp()
310-
->getBlock() == &body)
311-
return failure();
312-
}
313-
return success();
323+
[&fuseProducerResult,
324+
&realProducer](LoopLikeOpInterface loop) -> LogicalResult {
325+
// ensure that all surrounding outer loops are just yielding the result of
326+
// the inner loops.
327+
if (failed(isForOpYieldResultOfInnerLoop(loop)))
328+
return failure();
329+
Operation *originalOp =
330+
fuseProducerResult
331+
? fuseProducerResult->tiledAndFusedProducer.getDefiningOp()
332+
: realProducer->getDefiningOp();
333+
Block &body = loop->getRegion(0).front();
334+
return success(originalOp->getBlock() != &body);
314335
};
315336
SmallVector<LoopLikeOpInterface> outerLoops =
316337
getOuterNestLoopsWhile(sliceOp->getParentOfType<LoopLikeOpInterface>(),
@@ -515,21 +536,6 @@ static FailureOr<OpOperand *> getConsumerFromUses(Value val,
515536
return operand;
516537
}
517538

518-
/// Check if it is the ForOp that yield the result of inner loop
519-
static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) {
520-
if (auto forOp = dyn_cast<scf::ForOp>(loop.getOperation())) {
521-
for (auto &&[index, op] :
522-
llvm::enumerate(forOp.getBody()->getOperations())) {
523-
// If the orderIndex of inner loop is the last second one before the
524-
// yieldOp of ForOp, the given loop must yield the result of inner loop.
525-
if (isa<LoopLikeOpInterface>(op)) {
526-
return success((index + 2) == forOp.getBody()->getOperations().size());
527-
}
528-
}
529-
}
530-
return failure();
531-
}
532-
533539
/// Fetch the untiled consumer of a scf.for's result which is yielded by a
534540
/// tensor.insert_slice. This function makes the following assumptions that
535541
/// tensor.insert_slice has scf.yield as its only user.

test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,4 +535,53 @@ module {
535535
/// CHECK: return %[[FINAL_RESULT]]#1, %[[FINAL_RESULT]]#0
536536
return %0, %1 : tensor<16x32x32xf32>, tensor<16x32xf32>
537537
}
538+
}
539+
540+
// -----
541+
542+
#map = affine_map<(d0) -> (d0 * 128)>
543+
module {
544+
/// CHECK-LABEL: @fuse_tiled_producer
545+
func.func @fuse_tiled_producer(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>) -> tensor<256x256xf32> {
546+
%c0 = arith.constant 0 : index
547+
%c64 = arith.constant 64 : index
548+
%c128 = arith.constant 128 : index
549+
%cst = arith.constant 0.000000e+00 : f32
550+
%dest0 = tensor.empty() : tensor<256x256xf32>
551+
/// CHECK: %[[FINAL_RESULT:.*]]:2 = scf.forall (%{{.*}}) in (2, 2)
552+
%1 = scf.forall (%arg4, %arg5) in (2, 2) shared_outs(%arg6 = %dest0) -> tensor<256x256xf32> {
553+
%iv0 = affine.apply #map(%arg4)
554+
%iv1 = affine.apply #map(%arg5)
555+
%extracted_slice_1 = tensor.extract_slice %arg6[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
556+
%dest1 = linalg.fill ins(%cst : f32) outs(%extracted_slice_1 : tensor<128x128xf32>) -> tensor<128x128xf32>
557+
%extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
558+
%extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
559+
/// CHECK: scf.for
560+
/// CHECK: scf.for
561+
%2 = scf.for %arg7 = %c0 to %c128 step %c64 iter_args(%arg8 = %dest1) -> (tensor<128x128xf32>) {
562+
%3 = scf.for %arg9 = %c0 to %c128 step %c64 iter_args(%arg10 = %arg8) -> (tensor<128x128xf32>) {
563+
%extracted_slice_4 = tensor.extract_slice %arg10[%arg7, %arg9] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
564+
%extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg7, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
565+
%extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg9] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
566+
/// CHECK: %[[FILL_OUT:.*]] = linalg.fill
567+
/// CHECK: %[[MATMUL_OUT:.*]] = linalg.matmul
568+
/// CHECK: %[[EXP_OUT:.*]] = linalg.exp
569+
%4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
570+
%insert_slice = tensor.insert_slice %4 into %arg10[%arg7, %arg9] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
571+
/// CHECK: scf.yield {{.*}}, {{.*}} : tensor<128x128xf32>, tensor<128x128xf32>
572+
scf.yield %insert_slice : tensor<128x128xf32>
573+
}
574+
/// CHECK: scf.yield {{.*}}, {{.*}} : tensor<128x128xf32>, tensor<128x128xf32>
575+
scf.yield %3 : tensor<128x128xf32>
576+
}
577+
scf.forall.in_parallel {
578+
/// CHECK: tensor.parallel_insert_slice
579+
/// CHECK: tensor.parallel_insert_slice
580+
tensor.parallel_insert_slice %2 into %arg6[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
581+
}
582+
}
583+
%2 = linalg.exp ins(%1 : tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
584+
/// CHECK: return %[[FINAL_RESULT]]#1
585+
return %2 : tensor<256x256xf32>
586+
}
538587
}

0 commit comments

Comments
 (0)