Skip to content

Commit 3af1c48

Browse files
author
Mahesh Ravishankar
committed
Changes to SCFFuseProducerOfSliceResult to also return the operations created during fusion.
This is follow up to https://reviews.llvm.org/D145133 that allows propogating information about ops that are fused back to the caller. Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D146254
1 parent 091422a commit 3af1c48

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ struct SCFTileAndFuseOptions {
9696
struct SCFFuseProducerOfSliceResult {
9797
OpResult origProducer; // Original untiled producer.
9898
Value tiledAndFusedProducer; // Tile and fused producer value.
99+
SmallVector<Operation *> tiledOps;
99100
};
100101
std::optional<SCFFuseProducerOfSliceResult>
101102
tileAndFuseProducerOfSlice(RewriterBase &rewriter,

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -604,15 +604,17 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
604604
}
605605
}
606606
return scf::SCFFuseProducerOfSliceResult{fusableProducer,
607-
tileAndFuseResult->tiledValues[0]};
607+
tileAndFuseResult->tiledValues[0],
608+
tileAndFuseResult->tiledOps};
608609
}
609610

610611
/// Reconstruct the fused producer from within the tiled-and-fused code.
611612
void mlir::scf::yieldReplacementForFusedProducer(
612613
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
613614
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
614615
MutableArrayRef<scf::ForOp> loops) {
615-
auto [fusableProducer, fusedProducerValue] = fusedProducerInfo;
616+
auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
617+
fusedProducerInfo;
616618
SmallVector<Value> initValues;
617619
FailureOr<Value> initValue = tensor::getOrCreateDestination(
618620
rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
@@ -623,8 +625,11 @@ void mlir::scf::yieldReplacementForFusedProducer(
623625
yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
624626
resultOffsets, resultSizes, loops);
625627
}
626-
if (auto dstStyleProducer =
627-
fusedProducerValue.getDefiningOp<DestinationStyleOpInterface>()) {
628+
for (auto tileAndFusedOp : tileAndFusedOps) {
629+
auto dstStyleProducer =
630+
dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
631+
if (!dstStyleProducer)
632+
continue;
628633
Value dstValue =
629634
dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
630635
->get();

0 commit comments

Comments
 (0)