Skip to content

Commit 93eae9e

Browse files
committed
[MLIR] Add continuous tiling to TileUsingForOp
This patch adds continuous tiling options to TileUsingForOp. In cases where the tensor dimensions are not exactly divisible by the tile size, we are left with leftover tensor chunks that are irregularly tiled. This approach attempts to tile the leftover chunk with a smaller tile size and repeats this process recursively using exponentially diminishing tile sizes. The transform eventually generates a chain of loops that apply tiling using diminishing tile sizes. The transform lowers from the linalg dialect to scf dialect.
1 parent 5cb2ebc commit 93eae9e

File tree

6 files changed

+905
-11
lines changed

6 files changed

+905
-11
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1833,7 +1833,9 @@ def TileUsingForOp : Op<Transform_Dialect, "structured.tile_using_for",
18331833
be as many handles as `ShapedType::kDynamic` values in the
18341834
`static_sizes` attribute. A static size of `0` indicates that the dimension
18351835
should not be tiled. No loop will be generated for such dimensions. If all
1836-
tile sizes are `0`, this transform is effectively a no-op.
1836+
tile sizes are `0`, this transform is effectively a no-op. To apply
1837+
continuous tiling `continuous_tiles` needs to be supplied with as many
1838+
boolean values as there are nested loops.
18371839

18381840
This op returns handles to the tiled op (in the generated loop nest) and the
18391841
generated loops. The number of loops is the number of tile sizes that are
@@ -1859,6 +1861,7 @@ def TileUsingForOp : Op<Transform_Dialect, "structured.tile_using_for",
18591861
let arguments = (ins TransformHandleTypeInterface:$target,
18601862
Variadic<TransformAnyParamTypeOrAnyHandle>:$dynamic_sizes,
18611863
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$static_sizes,
1864+
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$continuous_tiles,
18621865
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$interchange,
18631866
DefaultValuedOptionalAttr<DenseBoolArrayAttr, "{}">:$scalable_sizes);
18641867
let results = (outs TransformHandleTypeInterface:$tiled_linalg_op,
@@ -1867,22 +1870,26 @@ def TileUsingForOp : Op<Transform_Dialect, "structured.tile_using_for",
18671870
OpBuilder<(ins "TypeRange":$loopTypes,
18681871
"Value":$target,
18691872
"ArrayRef<int64_t>":$staticTileSizes,
1873+
CArg<"ArrayRef<bool>", "{}">:$continuousTiles,
18701874
CArg<"ArrayRef<int64_t>", "{}">:$interchange,
18711875
CArg<"std::optional<ArrayRef<bool>>", "std::nullopt">:
18721876
$scalableSizes)>,
18731877
OpBuilder<(ins "TypeRange":$loopTypes,
18741878
"Value":$target,
18751879
"ArrayRef<OpFoldResult>":$mixedTileSizes,
1880+
CArg<"ArrayRef<bool>", "{}">:$continuousTiles,
18761881
CArg<"ArrayRef<int64_t>", "{}">:$interchange,
18771882
CArg<"std::optional<ArrayRef<bool>>", "std::nullopt">:
18781883
$scalableSizes)>,
18791884
OpBuilder<(ins "Value":$target,
18801885
"ArrayRef<int64_t>":$staticTileSizes,
1886+
CArg<"ArrayRef<bool>", "{}">:$continuousTiles,
18811887
CArg<"ArrayRef<int64_t>", "{}">:$interchange,
18821888
CArg<"std::optional<ArrayRef<bool>>", "std::nullopt">:
18831889
$scalableSizes)>,
18841890
OpBuilder<(ins "Value":$target,
18851891
"ArrayRef<OpFoldResult>":$mixedTileSizes,
1892+
CArg<"ArrayRef<bool>", "{}">:$continuousTiles,
18861893
CArg<"ArrayRef<int64_t>", "{}">:$interchange,
18871894
CArg<"std::optional<ArrayRef<bool>>", "std::nullopt">:
18881895
$scalableSizes)>,

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ struct SCFTilingOptions {
7171
mapping, [](auto attr) -> Attribute { return attr; });
7272
return *this;
7373
}
74+
75+
/// Specify which loops in the loop nest are to be continuously tiled.
76+
SmallVector<bool> continuousTileMappingVector = {};
77+
SCFTilingOptions &setCTileMapping(ArrayRef<bool> ctile) {
78+
continuousTileMappingVector =
79+
llvm::map_to_vector(ctile, [](auto attr) -> bool { return attr; });
80+
return *this;
81+
}
7482
};
7583

7684
/// Transformation information returned after tiling.
@@ -92,6 +100,12 @@ FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
92100
TilingInterface op,
93101
const SCFTilingOptions &options);
94102

103+
/// Method to continuously tile an op that implements the `TilingInterface`
104+
/// using `scf.for` for iterating over the tiles.
105+
FailureOr<SCFTilingResult>
106+
continuousTileUsingSCF(RewriterBase &rewriter, TilingInterface op,
107+
const SCFTilingOptions &options);
108+
95109
/// Options used to control tile + fuse.
96110
struct SCFTileAndFuseOptions {
97111
/// The tiling options used to control the tiling of the consumer.

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2476,39 +2476,42 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
24762476
void transform::TileUsingForOp::build(
24772477
OpBuilder &builder, OperationState &result, TypeRange loopTypes,
24782478
Value target, ArrayRef<int64_t> staticTileSizes,
2479-
ArrayRef<int64_t> interchange,
2479+
ArrayRef<bool> continuousTiles, ArrayRef<int64_t> interchange,
24802480
std::optional<ArrayRef<bool>> scalableSizes) {
24812481
return build(builder, result, loopTypes,
24822482
/*target=*/target,
24832483
/*mixedTileSizes=*/
24842484
getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2485-
interchange, scalableSizes);
2485+
continuousTiles, interchange, scalableSizes);
24862486
}
24872487

24882488
void transform::TileUsingForOp::build(
24892489
OpBuilder &builder, OperationState &result, Value target,
2490-
ArrayRef<int64_t> staticTileSizes, ArrayRef<int64_t> interchange,
2490+
ArrayRef<int64_t> staticTileSizes, ArrayRef<bool> continuousTiles,
2491+
ArrayRef<int64_t> interchange,
24912492
std::optional<ArrayRef<bool>> scalableSizes) {
24922493
build(builder, result, target,
24932494
getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
2494-
interchange, scalableSizes);
2495+
builder.getDenseBoolArrayAttr(continuousTiles), interchange,
2496+
scalableSizes);
24952497
}
24962498

24972499
void transform::TileUsingForOp::build(
24982500
OpBuilder &builder, OperationState &result, Value target,
2499-
ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<int64_t> interchange,
2501+
ArrayRef<OpFoldResult> mixedTileSizes, ArrayRef<bool> continuousTiles,
2502+
ArrayRef<int64_t> interchange,
25002503
std::optional<ArrayRef<bool>> scalableSizes) {
25012504
// Loop types are automaticaly splat by the callee, setting up one is
25022505
// enough.
25032506
SmallVector<Type> loopTypes(1, builder.getType<transform::AnyOpType>());
2504-
build(builder, result, loopTypes, target, mixedTileSizes, interchange,
2505-
scalableSizes);
2507+
build(builder, result, loopTypes, target, mixedTileSizes, continuousTiles,
2508+
interchange, scalableSizes);
25062509
}
25072510

25082511
void transform::TileUsingForOp::build(
25092512
OpBuilder &builder, OperationState &result, TypeRange loopTypes,
25102513
Value target, ArrayRef<OpFoldResult> mixedTileSizes,
2511-
ArrayRef<int64_t> interchange,
2514+
ArrayRef<bool> continuousTiles, ArrayRef<int64_t> interchange,
25122515
std::optional<ArrayRef<bool>> scalableSizes) {
25132516
SmallVector<int64_t> staticTileSizes;
25142517
SmallVector<Value> dynamicTileSizes;
@@ -2517,6 +2520,7 @@ void transform::TileUsingForOp::build(
25172520
// attributes for multiple variadic operands. In the absence of this,
25182521
// horrible bugs ensue.
25192522
auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2523+
auto continuousTilesAttr = builder.getDenseBoolArrayAttr(continuousTiles);
25202524
unsigned numExpectedLoops =
25212525
staticTileSizes.size() - llvm::count(staticTileSizes, 0);
25222526
SmallVector<Type> resultTypes;
@@ -2535,6 +2539,7 @@ void transform::TileUsingForOp::build(
25352539
/*target=*/target,
25362540
/*dynamic_sizes=*/dynamicTileSizes,
25372541
/*static_sizes=*/staticTileSizesAttr,
2542+
/*continuous_tiles=*/continuousTilesAttr,
25382543
/*interchange=*/builder.getDenseI64ArrayAttr(interchange),
25392544
/*scalable_sizes=*/expandedScalableSizes);
25402545
}
@@ -2675,8 +2680,15 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
26752680
}
26762681

26772682
tilingOptions.setInterchange(getInterchange());
2678-
FailureOr<scf::SCFTilingResult> maybeTilingResult =
2679-
tileUsingSCF(rewriter, tilingInterface, tilingOptions);
2683+
tilingOptions.setCTileMapping(getContinuousTiles());
2684+
2685+
FailureOr<scf::SCFTilingResult> maybeTilingResult;
2686+
if (tilingOptions.continuousTileMappingVector.empty())
2687+
maybeTilingResult =
2688+
tileUsingSCF(rewriter, tilingInterface, tilingOptions);
2689+
else
2690+
maybeTilingResult =
2691+
continuousTileUsingSCF(rewriter, tilingInterface, tilingOptions);
26802692
if (failed(maybeTilingResult))
26812693
return DiagnosedSilenceableFailure::definiteFailure();
26822694

@@ -2725,6 +2737,18 @@ ParseResult parseOptionalInterchange(OpAsmParser &parser,
27252737
return success();
27262738
}
27272739

2740+
ParseResult parseOptionalContinuousTiles(OpAsmParser &parser,
2741+
OperationState &result) {
2742+
if (failed(parser.parseOptionalKeyword("continuous_tiles")))
2743+
return success();
2744+
if (failed(parser.parseEqual()))
2745+
return failure();
2746+
result.addAttribute(
2747+
transform::TileUsingForOp::getContinuousTilesAttrName(result.name),
2748+
DenseBoolArrayAttr::parse(parser, Type{}));
2749+
return success();
2750+
}
2751+
27282752
void printOptionalInterchange(OpAsmPrinter &p,
27292753
ArrayRef<int64_t> interchangeVals) {
27302754
if (!interchangeVals.empty()) {
@@ -2747,6 +2771,7 @@ ParseResult transform::TileUsingForOp::parse(OpAsmParser &parser,
27472771
if (parser.parseOperand(target) || parser.getCurrentLocation(&operandLoc) ||
27482772
parseDynamicIndexList(parser, dynamicSizes, staticSizes, scalableVals) ||
27492773
parseOptionalInterchange(parser, result) ||
2774+
parseOptionalContinuousTiles(parser, result) ||
27502775
parser.parseOptionalAttrDict(result.attributes) ||
27512776
parser.parseColonType(functionalType))
27522777
return ParseResult::failure();

0 commit comments

Comments
 (0)