Skip to content

Commit a9efcbf

Browse files
authored
[MLIR] Add continuous tiling to transform dialect (llvm#82792)
This patch enables continuous tiling of a target structured op using diminishing tile sizes. In cases where the tensor dimensions are not exactly divisible by the tile size, we are left with leftover tensor chunks that are irregularly tiled. This approach enables tiling of the leftover chunk with a smaller tile size and repeats this process recursively using exponentially diminishing tile sizes. This eventually generates a chain of loops that apply tiling using diminishing tile sizes. Adds `continuous_tile_sizes` op to the transform dialect. This op, when given a tile size and a dimension, computes a series of diminishing tile sizes that can be used to tile the target along the given dimension. Additionally, this op also generates a series of chunk sizes that the corresponding tile sizes should be applied to along the given dimension. Adds `multiway` attribute to `transform.structured.split` that enables multiway splitting of a single target op along the given dimension, as specified in a list enumerating the chunk sizes.
1 parent 74a105a commit a9efcbf

File tree

11 files changed

+839
-98
lines changed

11 files changed

+839
-98
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 72 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,29 +1396,43 @@ def SplitOp : Op<Transform_Dialect, "structured.split",
13961396
DeclareOpInterfaceMethods<TransformOpInterface>,
13971397
ReportTrackingListenerFailuresOpTrait]> {
13981398
let description = [{
1399-
Indicates that the given `target` op should be split into two complementary
1399+
Splits the given `target` op into two or more complementary
14001400
parts, which combined cover the entire iteration domain of the original op.
14011401
The split is performed along the iteration space dimension provided as
1402-
attribute. In case of dimension overflow, the transformation fails. The
1403-
split is performed at the dimension iterator value specified as either the
1404-
static split point attribute when it is known at transform IR construction
1405-
time or as the handle to an operation producing a single index-typed value
1406-
when it is computed by payload IR. In the latter case, the static split
1402+
chunk size attribute specifying the size of the lower part; the remaining
1403+
range in the iteration space is assigned as the upper part. In case of
1404+
dimension overflow, the transformation fails. The split is performed at the
1405+
dimension iterator value specified as either the static chunk size
1406+
attribute when it is known at transform IR construction time or
1407+
as the handle to an operation producing a single index-typed value
1408+
when it is computed by payload IR. In the latter case, the chunk size
14071409
point must be set to `ShapedType::kDynamic` and the dynamic size handle
14081410
must point to as many value-producing operations as there are structured
14091411
operations pointed to by the target handle.
14101412

1411-
The operation consumes the target handle, but preserves the split point
1412-
handle if provided. It produces two new handles pointing to the two parts
1413-
of the structured op after splitting, in the same order as the target
1414-
operand, with the first handle corresponding to the part with lower
1415-
iteration space indices.
1413+
The operation consumes the target handle, but preserves the chunk size
1414+
handle if provided. Without the `multiway` attribute, it produces two
1415+
new handles pointing to the two parts of the structured op after splitting,
1416+
in the same order as the target operand, with the first handle
1417+
corresponding to the part with lower iteration space indices.
1418+
1419+
Multiway split mode is enabled by specifying the `multiway` attribute.
1420+
In this mode a single `target` op is split into multiple parts covering
1421+
the iteration space of the specified dimension. `static_chunk_sizes` and
1422+
`dynamic_chunk_sizes` in this case is a list of chunk sizes that the given
1423+
dimension should be split into. With `multiway` it produces two handles;
1424+
the first handle is a list of the multiple parts of the structured op
1425+
after splitting, where the target dimensions for each linalg op in the
1426+
list corresponds to the chunk sizes specfied in the input split list.
1427+
If the chunk sizes do not cover the entire iteration space, the leftover
1428+
chunk is the last payload in the first handle. The second handle is empty.
14161429
}];
14171430

14181431
let arguments = (ins TransformHandleTypeInterface:$target,
14191432
I64Attr:$dimension,
1420-
Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_split_point,
1421-
I64Attr:$static_split_point);
1433+
Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_chunk_sizes,
1434+
I64Attr:$static_chunk_sizes,
1435+
UnitAttr:$multiway);
14221436
let results = (outs TransformHandleTypeInterface:$first,
14231437
TransformHandleTypeInterface:$second);
14241438
let hasCustomAssemblyFormat = 1;
@@ -1819,6 +1833,51 @@ def TileReductionUsingForallOp :
18191833

18201834
}
18211835

1836+
//===----------------------------------------------------------------------===//
1837+
// ContinuousTileSizesOp
1838+
//===----------------------------------------------------------------------===//
1839+
1840+
def ContinuousTileSizesOp : Op<Transform_Dialect, "structured.continuous_tile_sizes",
1841+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1842+
DeclareOpInterfaceMethods<TransformOpInterface>,
1843+
ReportTrackingListenerFailuresOpTrait]> {
1844+
let description = [{
1845+
This transform emits the IR computing the list of (1) exponentially
1846+
diminishing tile sizes that are powers of 2; and (2) the corresponding
1847+
chunk-sizes the target op should be split into along the given dimension.
1848+
1849+
For example, for `target_size` 9, and `dimension` 0 for the following
1850+
linalg op as target
1851+
1852+
```
1853+
%0 = linalg.matmul ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
1854+
outs(%arg2: tensor<25x25xf32>)
1855+
```
1856+
1857+
the first result `tile_sizes` will be a list of diminishing tile sizes
1858+
9, 4, 2, 1; and the second result will be a list of chunk sizes
1859+
18, 4, 2, 1 that the corresponding dimension should be split into.
1860+
1861+
After the target op has been split along the given dimension (for example
1862+
using multiway split), each chunk can be tiled with the corresponding tile
1863+
size in the `tile_sizes` list generated as a result of this op.
1864+
1865+
Specifying the output type as !transform.param<i64> will cause `tile_sizes`
1866+
and `chunk_sizes` to be computed statically and not dynamically.
1867+
}];
1868+
1869+
let arguments = (ins TransformHandleTypeInterface:$target,
1870+
ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension,
1871+
ConfinedAttr<I64Attr, [IntNonNegative]>:$target_size);
1872+
let results = (outs TransformAnyParamTypeOrAnyHandle:$tile_sizes,
1873+
TransformAnyParamTypeOrAnyHandle:$chunk_sizes);
1874+
let hasVerifier = 1;
1875+
let assemblyFormat =
1876+
"$target attr-dict `:` custom<ContinuousTileSizeTypes>("
1877+
"type($target), type($tile_sizes), type($chunk_sizes))";
1878+
1879+
}
1880+
18221881
//===----------------------------------------------------------------------===//
18231882
// TileUsingForOp
18241883
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,15 @@ struct MultiSizeSpecificationBase {
801801
/// Number of tiles associated with each size.
802802
T lowTripCount, highTripCount;
803803
};
804+
805+
template <typename T>
806+
struct ContinuousTileSizeSpecificationBase {
807+
/// Tile sizes.
808+
SmallVector<T> tileSizes;
809+
/// Number of tiles associated with each size.
810+
SmallVector<T> tripCounts;
811+
};
812+
804813
} // namespace detail
805814

806815
/// A description of a multi-size tiling comprising tile sizes and numbers of
@@ -811,6 +820,11 @@ struct MultiSizeSpecification
811820
struct StaticMultiSizeSpecification
812821
: public detail::MultiSizeSpecificationBase<int64_t> {};
813822

823+
struct ContinuousTileSizeSpecification
824+
: public detail::ContinuousTileSizeSpecificationBase<Value> {};
825+
struct StaticContinuousTileSizeSpecification
826+
: public detail::ContinuousTileSizeSpecificationBase<int64_t> {};
827+
814828
/// Emits the IR computing the multi-sized tiling specification with two tile
815829
/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
816830
/// that there exist numbers of tiles with these sizes that fully cover the
@@ -846,6 +860,13 @@ FailureOr<StaticMultiSizeSpecification>
846860
computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
847861
int64_t divisor);
848862

863+
FailureOr<StaticContinuousTileSizeSpecification>
864+
computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
865+
unsigned targetSize);
866+
FailureOr<ContinuousTileSizeSpecification>
867+
computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
868+
unsigned dimension, OpFoldResult targetSize,
869+
bool emitAssertions);
849870
/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
850871
/// tiling by `numThreads`.
851872
/// If non-empty, the `mapping` is added as an attribute to the

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -624,7 +624,10 @@ def ForeachOp : TransformDialectOp<"foreach",
624624
Each iteration gets executed by co-indexing the payloads of the arguments
625625
and mapping the body's arguments to these tuples, as though iterating over
626626
the zipped together `targets`. As such, in each iteration, the size of the
627-
payload of each of the body's block arguments is exactly one.
627+
payload of each of the body's block arguments is exactly one. The attribute
628+
`zip_shortest` can be used if the targets vary in their number of payloads;
629+
this will limit the iterations to only the number of payloads found in the
630+
shortest target.
628631

629632
This op always reads the target handles. Furthermore, it consumes a handle
630633
if there is a transform op in the body that consumes the corresponding
@@ -645,11 +648,12 @@ def ForeachOp : TransformDialectOp<"foreach",
645648
rollback capabilities.
646649
}];
647650

648-
let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets);
651+
let arguments = (ins Variadic<Transform_AnyHandleOrParamType>:$targets,
652+
UnitAttr:$zip_shortest);
649653
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
650654
let regions = (region SizedRegion<1>:$body);
651655
let assemblyFormat =
652-
"$targets `:` type($targets) (`->` type($results)^)? $body attr-dict";
656+
"$targets attr-dict `:` type($targets) (`->` type($results)^)? $body";
653657
let hasVerifier = 1;
654658

655659
let extraClassDeclaration = [{

0 commit comments

Comments
 (0)