@@ -1068,12 +1068,12 @@ getUntiledProducerFromSliceSource(OpOperand *source,
1068
1068
return {dyn_cast<OpResult>(source->get ()), destinationIterArg};
1069
1069
}
1070
1070
1071
- // / Implementation of fusing producer of a single slice by computing the
1071
+ // / Basic implementation of fusing producer of a single slice by computing the
1072
1072
// / slice of the producer in-place.
1073
- std::optional<scf::SCFFuseProducerOfSliceResult>
1074
- mlir::scf::tileAndFuseProducerOfSlice (
1075
- RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1076
- MutableArrayRef<LoopLikeOpInterface> loops) {
1073
+ static std::optional<scf::SCFFuseProducerOfSliceResult>
1074
+ tileAndFuseProducerOfSliceImpl (RewriterBase &rewriter,
1075
+ tensor::ExtractSliceOp candidateSliceOp,
1076
+ MutableArrayRef<LoopLikeOpInterface> loops) {
1077
1077
// 1. Get the producer of the source (potentially walking through
1078
1078
// `iter_args` of nested `scf.for`)
1079
1079
auto [fusableProducer, destinationInitArg] =
@@ -1185,6 +1185,145 @@ mlir::scf::tileAndFuseProducerOfSlice(
1185
1185
tileAndFuseResult->tiledOps };
1186
1186
}
1187
1187
1188
+ // / Get the real producer from candidate ExtractSliceOp
1189
+ // /
1190
+ // / ```
1191
+ // / %0 = producer
1192
+ // / %1 = scf.for(%arg1 = %0)
1193
+ // / %2 = extract %arg1
1194
+ // / %3 = scf.for(%arg2 = %2)
1195
+ // / %4 = extract %args2
1196
+ // / ...
1197
+ // / ```
1198
+ // /
1199
+ // / @param candidateSliceOp: %4 = extract %args2
1200
+ // / @param backwardSlice: in-out parameter populated by backward extractSliceOps
1201
+ // / @return OpResult Producer : %0 = producer
1202
+ static FailureOr<OpResult> getRealProducerFromExtractSliceOp (
1203
+ Operation *candidateSliceOp,
1204
+ SmallVector<tensor::ExtractSliceOp> &backwardSlice, int curDepth = 0 ,
1205
+ int maxDepth = 5 ) {
1206
+ if (!isa<tensor::ExtractSliceOp>(candidateSliceOp))
1207
+ return failure ();
1208
+ // control recursive time in avoid of stack overflow
1209
+ if (curDepth > maxDepth)
1210
+ return failure ();
1211
+
1212
+ auto extractOp = cast<tensor::ExtractSliceOp>(candidateSliceOp);
1213
+ backwardSlice.push_back (extractOp);
1214
+ Value rootSource = extractOp.getSourceMutable ().get ();
1215
+
1216
+ while (true ) {
1217
+ if (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
1218
+ if (auto outerLoop = dyn_cast<LoopLikeOpInterface>(
1219
+ iterArg.getOwner ()->getParentOp ())) {
1220
+ rootSource = outerLoop.getTiedLoopInit (iterArg)->get ();
1221
+ continue ;
1222
+ }
1223
+ return failure ();
1224
+ } else if (auto sliceOp =
1225
+ rootSource.getDefiningOp <tensor::ExtractSliceOp>()) {
1226
+ // walk up loop to find larger candidate extractSliceOp
1227
+ return getRealProducerFromExtractSliceOp (sliceOp, backwardSlice,
1228
+ curDepth + 1 );
1229
+ }
1230
+ break ;
1231
+ }
1232
+ return dyn_cast<OpResult>(rootSource);
1233
+ }
1234
+
1235
+ // / Recursively find the outer nest loops of given loop(included) while the
1236
+ // / predict function succeed, sorted from outer to inner.
1237
+ // /
1238
+ // / @param loop: target loop, note that this loop will be also included. I.e.
1239
+ // / if no other nest loops were found, just return itself.
1240
+ // / @param pred: predict function, the termination condition of recursive
1241
+ // / process.
1242
+ // / @return Outer Nest Loops: nest loops outside given target loop(included).
1243
+ // /
1244
+ // / E.g.
1245
+ // /
1246
+ // / ```
1247
+ // / %0 = scf.for()
1248
+ // / %1 = scf.for()
1249
+ // / %2 = scf.for()
1250
+ // / ```
1251
+ // /
1252
+ // / If `%2 = scf.for` is given without specific prediction function, this
1253
+ // / function will return three nest loops: %0 + %1 + %2.
1254
+ static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile (
1255
+ LoopLikeOpInterface loop,
1256
+ const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
1257
+ SmallVector<LoopLikeOpInterface> nestLoops = {loop};
1258
+ auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp ());
1259
+ while (outerLoop && succeeded (pred (outerLoop))) {
1260
+ nestLoops.push_back (outerLoop);
1261
+ outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp ());
1262
+ }
1263
+ // sorted from outer to inner
1264
+ return {nestLoops.rbegin (), nestLoops.rend ()};
1265
+ }
1266
+
1267
+ // / Enhanced version for basic implementation of fusing producer, which can deal
1268
+ // / with multi-level candidates. E.g.
1269
+ // /
1270
+ // / ```
1271
+ // / %0 = untiled_producer
1272
+ // / %1 = scf.for(%arg1 = %0)
1273
+ // / %2 = tensor.extract_slice %arg1
1274
+ // / %3 = scf.for(%arg2 = %2)
1275
+ // / %4 = tensor.extract_slice %args2
1276
+ // / %5 = tiled_consumer ins(%4)
1277
+ // / ```
1278
+ // /
1279
+ // / This utility can fuse untiled producer at `%4 = tensor.extract_slice` within
1280
+ // / inner loop `%3 = scf.for`.
1281
+ std::optional<scf::SCFFuseProducerOfSliceResult>
1282
+ mlir::scf::tileAndFuseProducerOfSlice (RewriterBase &rewriter,
1283
+ Operation *candidateSliceOp) {
1284
+ SmallVector<tensor::ExtractSliceOp> backwardSlice;
1285
+ if (failed (
1286
+ getRealProducerFromExtractSliceOp (candidateSliceOp, backwardSlice))) {
1287
+ return std::nullopt;
1288
+ }
1289
+
1290
+ std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
1291
+ // reverse from outer to inner
1292
+ std::reverse (backwardSlice.begin (), backwardSlice.end ());
1293
+ // multiple application of `tileAndFuseProducerOfSliceImpl`
1294
+ for (auto &&[index, sliceOp] : llvm::enumerate (backwardSlice)) {
1295
+ // get nest loops between next candidate sliceOp and tiled producer.
1296
+ auto whileProducerOutOfLoopBlock =
1297
+ [&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
1298
+ if (fuseProducerResult) {
1299
+ Block &body = loop->getRegion (0 ).front ();
1300
+ if (fuseProducerResult->tiledAndFusedProducer .getDefiningOp ()
1301
+ ->getBlock () == &body)
1302
+ return failure ();
1303
+ }
1304
+ return success ();
1305
+ };
1306
+ SmallVector<LoopLikeOpInterface> outerLoops =
1307
+ getOuterNestLoopsWhile (sliceOp->getParentOfType <LoopLikeOpInterface>(),
1308
+ whileProducerOutOfLoopBlock);
1309
+ fuseProducerResult =
1310
+ tileAndFuseProducerOfSliceImpl (rewriter, sliceOp, outerLoops);
1311
+ if (!fuseProducerResult) {
1312
+ return std::nullopt;
1313
+ }
1314
+ }
1315
+ return fuseProducerResult;
1316
+ }
1317
+
1318
+ // / Implementation of fusing producer of a single slice by computing the
1319
+ // / slice of the producer in-place.
1320
+ std::optional<scf::SCFFuseProducerOfSliceResult>
1321
+ mlir::scf::tileAndFuseProducerOfSlice (
1322
+ RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
1323
+ MutableArrayRef<LoopLikeOpInterface> loops) {
1324
+ return tileAndFuseProducerOfSliceImpl (rewriter, candidateSliceOp, loops);
1325
+ }
1326
+
1188
1327
// / Reconstruct the fused producer from within the tiled-and-fused code.
1189
1328
LogicalResult mlir::scf::yieldReplacementForFusedProducer (
1190
1329
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
0 commit comments