Skip to content

Commit 55f9518

Browse files
Address comments.
1 parent d1f1103 commit 55f9518

File tree

4 files changed

+30
-27
lines changed

4 files changed

+30
-27
lines changed

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ struct SCFTilingOptions {
5858
/// `scf.for`)
5959
SmallVector<Attribute> mappingVector = {};
6060
SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
61-
mappingVector = llvm::to_vector(
62-
llvm::map_range(mapping, [](auto attr) -> Attribute { return attr; }));
61+
mappingVector = llvm::map_to_vector(
62+
mapping, [](auto attr) -> Attribute { return attr; });
6363
return *this;
6464
}
6565
};
@@ -93,7 +93,7 @@ struct SCFTileAndFuseOptions {
9393
}
9494
};
9595

96-
/// Method to tile and op that implements the `TilingInterface` using
96+
/// Method to tile an op that implements the `TilingInterface` using
9797
/// `scf.forall`.
9898
FailureOr<SCFTilingResult>
9999
tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -767,8 +767,7 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
767767

768768
// 3. Build the offsets, sizes and steps for the tile and distributed loops.
769769
SmallVector<OpFoldResult> lbs, ubs, steps;
770-
for (auto [index, tileSize, loopRange] :
771-
llvm::enumerate(tileSizeVector, loopRanges)) {
770+
for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
772771
if (isConstantIntValue(tileSize, 0))
773772
continue;
774773
lbs.push_back(loopRange.offset);
@@ -781,7 +780,7 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
781780
if (failed(tensor::getOrCreateDestinations(rewriter, loc, op, dest)))
782781
return op->emitOpError("failed to get destination tensors");
783782

784-
// 5. Build the device mapping attribute;
783+
// 5. Build the device mapping attribute.
785784
std::optional<ArrayAttr> mappingAttr;
786785
if (!options.mappingVector.empty()) {
787786
mappingAttr = rewriter.getArrayAttr(ArrayRef(options.mappingVector));
@@ -796,13 +795,10 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
796795
// 7. Get the tile offset and sizes.
797796
rewriter.setInsertionPoint(forallOp.getTerminator());
798797
SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
799-
tiledOffsets.reserve(loopRanges.size());
800-
tiledSizes.reserve(loopRanges.size());
801798
ValueRange ivs = forallOp.getInductionVars();
802799
{
803800
int materializedLoopNum = 0;
804-
for (auto [index, tileSize, loopRange] :
805-
llvm::enumerate(tileSizeVector, loopRanges)) {
801+
for (auto [tileSize, loopRange] : llvm::zip(tileSizeVector, loopRanges)) {
806802
if (isConstantIntValue(tileSize, 0)) {
807803
tiledOffsets.push_back(loopRange.offset);
808804
tiledSizes.push_back(loopRange.size);
@@ -816,15 +812,15 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
816812
}
817813

818814
// 8. Tile the operation. Clone the operation to allow fix up of destination
819-
// operands
815+
// operands.
820816
ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
821817
Operation *clonedOp =
822818
cloneOpAndUpdateDestinationArgs(rewriter, op, destBbArgs);
823819
FailureOr<TilingResult> tilingResult =
824820
cast<TilingInterface>(clonedOp).getTiledImplementation(
825821
rewriter, tiledOffsets, tiledSizes);
826822
if (failed(tilingResult))
827-
return clonedOp->emitError("Failed to tile op: ");
823+
return clonedOp->emitError("failed to tile op: ");
828824
rewriter.eraseOp(clonedOp);
829825

830826
// 9. Parallel insert back into the result tensor.
@@ -836,24 +832,25 @@ mlir::scf::tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
836832
SmallVector<OpFoldResult> resultOffsets, resultSizes;
837833
if (failed(op.getResultTilePosition(rewriter, index, tiledOffsets,
838834
tiledSizes, resultOffsets,
839-
resultSizes)))
835+
resultSizes))) {
840836
return op->emitOpError("output offsets couldn't be calculated");
837+
}
838+
841839
SmallVector<OpFoldResult> strides(resultSizes.size(),
842840
rewriter.getIndexAttr(1));
843-
844-
// 5.b. Parallel insertions are inserted at the end of the combining
841+
// 9.b. Parallel insertions are inserted at the end of the combining
845842
// terminator.
846843
rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
847844
rewriter.create<tensor::ParallelInsertSliceOp>(
848845
loc, tiledValue, destBBArg, resultOffsets, resultSizes, strides);
849846
}
850847

851-
// 10. Return the tiling result;
848+
// 10. Return the tiling result.
852849
return scf::SCFTilingResult{
853850
tilingResult->tiledOps,
854851
{forallOp.getOperation()},
855-
llvm::to_vector(llvm::map_range(forallOp.getResults(),
856-
[](auto val) -> Value { return val; }))};
852+
llvm::map_to_vector(forallOp.getResults(),
853+
[](auto val) -> Value { return val; })};
857854
}
858855

859856
//===----------------------------------------------------------------------===//

mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ func.func @simple_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
3434
// CHECK: scf.forall.in_parallel {
3535
// CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]]
3636
// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], %[[TS_X]]] [1, 1]
37+
// CHECK: mapping = [#gpu.block<y>, #gpu.block<x>]
3738
// CHECK: return %[[RESULT]]
3839

3940
// -----

mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1818
#include "mlir/Dialect/Func/IR/FuncOps.h"
19+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1920
#include "mlir/Dialect/Linalg/IR/Linalg.h"
2021
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
2122
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -443,9 +444,9 @@ struct TestTilingInterfacePass
443444
TestTilingInterfacePass(const TestTilingInterfacePass &pass)
444445
: PassWrapper(pass) {}
445446
void getDependentDialects(DialectRegistry &registry) const override {
446-
registry.insert<affine::AffineDialect, linalg::LinalgDialect,
447-
memref::MemRefDialect, scf::SCFDialect,
448-
tensor::TensorDialect>();
447+
registry.insert<affine::AffineDialect, gpu::GPUDialect,
448+
linalg::LinalgDialect, memref::MemRefDialect,
449+
scf::SCFDialect, tensor::TensorDialect>();
449450
linalg::registerTilingInterfaceExternalModels(registry);
450451
tensor::registerTilingInterfaceExternalModels(registry);
451452
}
@@ -506,15 +507,16 @@ static void addPatternForTiling(MLIRContext *context,
506507
patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
507508
}
508509

509-
static void addPatternForTilingUsingForall(MLIRContext *context,
510-
RewritePatternSet &patterns,
511-
StringRef filterName,
512-
ArrayRef<int64_t> tileSizes,
513-
ArrayRef<int64_t> interchange = {}) {
510+
static void addPatternForTilingUsingForall(
511+
MLIRContext *context, RewritePatternSet &patterns, StringRef filterName,
512+
ArrayRef<int64_t> tileSizes,
513+
ArrayRef<DeviceMappingAttrInterface> mapping = {},
514+
ArrayRef<int64_t> interchange = {}) {
514515
scf::SCFTilingOptions tilingOptions;
515516
SmallVector<OpFoldResult> tileSizesOfr =
516517
getAsIndexOpFoldResult(context, tileSizes);
517518
tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange);
519+
tilingOptions.setMapping(mapping);
518520
TransformationFilter filter(StringAttr::get(context, filterName),
519521
StringAttr::get(context, "tiled"));
520522
patterns.add<TestTileUsingSCFForallOp>(context, tilingOptions, filter);
@@ -581,7 +583,10 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
581583
}
582584
if (testTilingForAll) {
583585
// 1. Tiling M and N dims of `linalg.matmul` on tensors.
584-
addPatternForTilingUsingForall(context, patterns, "simple_gemm", {10, 20});
586+
addPatternForTilingUsingForall(
587+
context, patterns, "simple_gemm", {10, 20},
588+
{gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimY),
589+
gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimX)});
585590
// 2. Tiling 3D parallel generic op which implements a transpose.
586591
addPatternForTilingUsingForall(context, patterns,
587592
"parallel_generic_transpose", {10, 0, 20});

0 commit comments

Comments
 (0)