Skip to content

Commit fd48251

Browse files
[mlir][TilingInterface] Add scf::tileUsingSCFForallOp method to tile using the interface to generate scf::forall.
Similar to `scf::tileUsingSCFForOp` that is a method that tiles operations that implement the `TilingInterface`, using `scf.for` operations, this method introduces tiling of operations using `scf.forall`. Most of this implementation is derived from `linalg::tileToForallOp` method. Eventually that method will either be deprecated or moved to use the method introduced here.
1 parent f2517cb commit fd48251

File tree

4 files changed

+256
-0
lines changed

4 files changed

+256
-0
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,17 @@ struct SCFTilingOptions {
5151
interchangeVector = llvm::to_vector(interchange);
5252
return *this;
5353
}
54+
55+
/// Specify mapping of loops to devices. This is only respected when the loop
56+
/// constructs support such a mapping (like `scf.forall`). Will be ignored
57+
/// when using loop constructs that dont support such a mapping (like
58+
/// `scf.for`)
59+
SmallVector<Attribute> mappingVector = {};
60+
SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
61+
mappingVector = llvm::to_vector(
62+
llvm::map_range(mapping, [](auto attr) -> Attribute { return attr; }));
63+
return *this;
64+
}
5465
};
5566

5667
/// Transformation information returned after tiling.
@@ -82,6 +93,12 @@ struct SCFTileAndFuseOptions {
8293
}
8394
};
8495

96+
/// Method to tile and op that implements the `TilingInterface` using
97+
/// `scf.forall`.
98+
FailureOr<SCFTilingResult>
99+
tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
100+
const SCFTilingOptions &options);
101+
85102
/// Fuse the producer of the source of `candidateSliceOp` by computing the
86103
/// required slice of the producer in-place. Note that the method
87104
/// replaces the uses of `candidateSliceOp` with the tiled and fused producer

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,24 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
122122
b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
123123
}
124124

125+
/// Clones the operation and updates the destination if the operation
126+
/// implements the `DestinationStyleOpInterface`.
127+
static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
128+
Operation *op,
129+
ValueRange newDestArgs) {
130+
Operation *clonedOp = rewriter.clone(*op);
131+
if (auto destinationStyleOp =
132+
dyn_cast<DestinationStyleOpInterface>(clonedOp)) {
133+
// Note that this is assuming that
134+
auto [start, end] = destinationStyleOp.getDpsInitsPositionRange();
135+
assert((end - start == newDestArgs.size()) &&
136+
"expected as many new destination args as number of inits of the "
137+
"operation");
138+
clonedOp->setOperands(start, end - start, newDestArgs);
139+
}
140+
return clonedOp;
141+
}
142+
125143
/// Generate an empty loop nest that represents the tiled loop nest shell.
126144
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
127145
/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
@@ -728,6 +746,121 @@ mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
728746
getAsOperations(forLoops), replacements};
729747
}
730748

749+
//===----------------------------------------------------------------------===//
750+
// tileUsingSCFForAllOp implementation.
751+
//===----------------------------------------------------------------------===//
752+
753+
FailureOr<scf::SCFTilingResult>
754+
mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
755+
const scf::SCFTilingOptions &options) {
756+
Location loc = op->getLoc();
757+
OpBuilder::InsertionGuard g(rewriter);
758+
759+
// 1. Get the range of loops that are represented by the operation.
760+
SmallVector<Range> loopRanges = op.getIterationDomain(rewriter);
761+
if (loopRanges.empty())
762+
return op->emitOpError("expected non-empty loop ranges");
763+
auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
764+
if (llvm::any_of(loopRanges, hasStrideOne))
765+
return op->emitOpError("only stride-1 supported atm");
766+
767+
// 2. Get the tile sizes. If tile size is 0, it is not tiled and distributed.
768+
// To make it easier, pad the tile sizes to loopRanges.size with value 0.
769+
SmallVector<OpFoldResult> tileSizeVector =
770+
options.tileSizeComputationFunction(rewriter, op);
771+
tileSizeVector.resize(loopRanges.size(), rewriter.getIndexAttr(0));
772+
773+
// 3. Build the offsets, sizes and steps for the tile and distributed loops.
774+
SmallVector<OpFoldResult> lbs, ubs, steps;
775+
for (auto [index, tileSize, loopRange] :
776+
llvm::enumerate(tileSizeVector, loopRanges)) {
777+
if (isConstantIntValue(tileSize, 0))
778+
continue;
779+
lbs.push_back(loopRange.offset);
780+
ubs.push_back(loopRange.size);
781+
steps.push_back(tileSize);
782+
}
783+
784+
// 4. Gather destination tensors.
785+
SmallVector<Value> dest;
786+
if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
787+
return op->emitOpError("failed to get destination tensors");
788+
789+
// 5. Build the device mapping attribute;
790+
std::optional<ArrayAttr> mappingAttr;
791+
if (!options.mappingVector.empty()) {
792+
mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
793+
}
794+
795+
// 6. Create the ForallOp. We don't use the lambda body-builder
796+
// version because we require the use of RewriterBase in the body, so we
797+
// manually move the insertion point to the body below.
798+
auto forallOp =
799+
rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, dest, mappingAttr);
800+
801+
// 7. Get the tile offset and sizes.
802+
rewriter.setInsertionPoint(forallOp.getTerminator());
803+
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
804+
tiledOffsets.reserve(loopRanges.size());
805+
tiledSizes.reserve(loopRanges.size());
806+
ValueRange ivs = forallOp.getInductionVars();
807+
{
808+
int materializedLoopNum = 0;
809+
for (auto [index, tileSize, loopRange] :
810+
llvm::enumerate(tileSizeVector, loopRanges)) {
811+
if (isConstantIntValue(tileSize, 0)) {
812+
tiledOffsets.push_back(loopRange.offset);
813+
tiledSizes.push_back(loopRange.size);
814+
continue;
815+
}
816+
Value iv = ivs[materializedLoopNum++];
817+
tiledOffsets.push_back(iv);
818+
tiledSizes.push_back(
819+
getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
820+
}
821+
}
822+
823+
// 8. Tile the operation. Clone the operation to allow fix up of destination
824+
// operands
825+
ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
826+
Operation *clonedOp =
827+
cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
828+
FailureOr<TilingResult> tilingResult =
829+
cast<TilingInterface>(clonedOp).getTiledImplementation(
830+
rewriter, tiledOffsets, tiledSizes);
831+
if (failed(tilingResult))
832+
return clonedOp->emitError("Failed to tile op: ");
833+
rewriter.eraseOp(clonedOp);
834+
835+
// 9. Parallel insert back into the result tensor.
836+
for (auto [index, tiledValue, destBBArg] :
837+
llvm::enumerate(tilingResult->tiledValues, destBbArgs)) {
838+
// 9.a. Partial subset information is inserted just before the terminator.
839+
rewriter.setInsertionPoint(forallOp.getTerminator());
840+
841+
SmallVector<OpFoldResult> resultOffsets, resultSizes;
842+
if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
843+
tiledSizes, resultOffsets,
844+
resultSizes)))
845+
return op->emitOpError("output offsets couldn't be calculated");
846+
SmallVector<OpFoldResult> strides(resultSizes.size(),
847+
rewriter.getIndexAttr(1));
848+
849+
// 5.b. Parallel insertions are inserted at the end of the combining
850+
// terminator.
851+
rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
852+
rewriter.create<tensor::ParallelInsertSliceOp>(
853+
loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
854+
}
855+
856+
// 10. Return the tiling result;
857+
return scf::SCFTilingResult{
858+
tilingResult->tiledOps,
859+
{forallOp.getOperation()},
860+
llvm::to_vector(llvm::map_range(forallOp.getResults(),
861+
[](auto val) -> Value { return val; }))};
862+
}
863+
731864
//===----------------------------------------------------------------------===//
732865
// lowerToLoopsUsingSCFForOp implementation.
733866
//===----------------------------------------------------------------------===//
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s
2+
3+
func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
4+
%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> {
5+
%0 = linalg.matmul {__internal_transform__ = "simple_gemm"}
6+
ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
7+
outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
8+
return %0 : tensor<?x?xf32>
9+
}
10+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
11+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
12+
// CHECK: func.func @simple_matmul(
13+
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
14+
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
15+
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
16+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
17+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
18+
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
19+
// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
20+
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
21+
// CHECK: %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) =
22+
// CHECK-SAME: (0, 0) to (%[[M]], %[[N]]) step (10, 20) shared_outs(%[[INIT:.+]] = %[[ARG2]])
23+
// CHECK: %[[TS_Y:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]]
24+
// CHECK: %[[TS_X:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[N]]]
25+
// CHECK: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]]
26+
// CHECK-SAME: [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1]
27+
// CHECK: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]]
28+
// CHECK-SAME: [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1]
29+
// CHECK: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]]
30+
// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
31+
// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul
32+
// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] :
33+
// CHECK-SAME: outs(%[[INIT_TILE]] :
34+
// CHECK: scf.forall.in_parallel {
35+
// CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]]
36+
// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
37+
// CHECK: return %[[RESULT]]

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,51 @@ struct TestTileUsingSCFForOp
186186
TransformationFilter filter;
187187
};
188188

189+
/// Pattern for testing `tileUsingSCFForallOp` (that tiles operations using
190+
/// the `TilingInterface` with `scf.forall` ops for iterating over the tiles)
191+
/// while using a `filter` to avoid recursive application.
192+
struct TestTileUsingSCFForallOp
193+
: public OpInterfaceRewritePattern<TilingInterface> {
194+
TestTileUsingSCFForallOp(MLIRContext *context, scf::SCFTilingOptions options,
195+
TransformationFilter filter = TransformationFilter(),
196+
PatternBenefit benefit = 1)
197+
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
198+
options(std::move(options)), filter(std::move(filter)) {}
199+
200+
/// Construct a generic pattern applied to `opName`.
201+
TestTileUsingSCFForallOp(StringRef opName, MLIRContext *context,
202+
scf::SCFTilingOptions options,
203+
TransformationFilter filter = TransformationFilter(),
204+
PatternBenefit benefit = 1)
205+
: OpInterfaceRewritePattern<TilingInterface>(context, benefit),
206+
options(std::move(options)), filter(std::move(filter)) {}
207+
208+
LogicalResult matchAndRewrite(TilingInterface op,
209+
PatternRewriter &rewriter) const override {
210+
if (failed(filter.checkAndNotify(rewriter, op)))
211+
return failure();
212+
213+
FailureOr<scf::SCFTilingResult> tilingResult =
214+
scf::tileUsingSCFForallOp(rewriter, op, options);
215+
if (failed(tilingResult))
216+
return rewriter.notifyMatchFailure(op, "failed to tile operation");
217+
218+
if (op->getNumResults()) {
219+
rewriter.replaceOp(op, tilingResult->replacements);
220+
} else {
221+
rewriter.eraseOp(op);
222+
}
223+
224+
for (auto *tiledOp : tilingResult->tiledOps)
225+
filter.replaceTransformationFilter(rewriter, tiledOp);
226+
return success();
227+
}
228+
229+
private:
230+
scf::SCFTilingOptions options;
231+
TransformationFilter filter;
232+
};
233+
189234
/// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern
190235
/// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
191236
/// ops for iterating over the tiles) while using a `filter` to avoid recursive
@@ -415,6 +460,12 @@ struct TestTilingInterfacePass
415460
"Test tiling using TilingInterface with scf.for operations"),
416461
llvm::cl::init(false)};
417462

463+
Option<bool> testTilingForAll{
464+
*this, "tile-using-scf-forall",
465+
llvm::cl::desc(
466+
"Test tiling using TilingInterface with scf.forall operations"),
467+
llvm::cl::init(false)};
468+
418469
Option<bool> testTileConsumerFuseAndYieldProducer{
419470
*this, "tile-consumer-fuse-and-yield-producer-using-scf-for",
420471
llvm::cl::desc(
@@ -455,6 +506,20 @@ static void addPatternForTiling(MLIRContext *context,
455506
patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
456507
}
457508

509+
static void addPatternForTilingUsingForall(MLIRContext *context,
510+
RewritePatternSet &patterns,
511+
StringRef filterName,
512+
ArrayRef<int64_t> tileSizes,
513+
ArrayRef<int64_t> interchange = {}) {
514+
scf::SCFTilingOptions tilingOptions;
515+
SmallVector<OpFoldResult> tileSizesOfr =
516+
getAsIndexOpFoldResult(context, tileSizes);
517+
tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
518+
TransformationFilter filter(StringAttr::get(context, filterName),
519+
StringAttr::get(context, "tiled"));
520+
patterns.add<TestTileUsingSCFForallOp>(context, tilingOptions, filter);
521+
}
522+
458523
static void addPatternForTileFuseAndYield(MLIRContext *context,
459524
RewritePatternSet &patterns,
460525
StringRef filterName,
@@ -514,6 +579,10 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
514579
addPatternForTiling(context, patterns, "simple_copy_memref", {10, 20});
515580
return;
516581
}
582+
if (testTilingForAll) {
583+
addPatternForTilingUsingForall(context, patterns, "simple_gemm", {10, 20});
584+
return;
585+
}
517586
if (testTileConsumerAndFuseProducer) {
518587
// 1. Tile and fuse of gemm with fill producer and bias-add consumer.
519588
addPatternForTileAndFuse(context, patterns, "fusion", {10, 20});

0 commit comments

Comments
 (0)