Skip to content

Commit 37f4d82

Browse files
committed
[MLIR] Add continuous tiling to Transform dialect
Add continuous tiling op structured.continuous_tile to the transform dialect that returns as result a list of exponentially diminishing tile sizes and a list of split points to do a multiway split of the target linalg op along the specified dimension.
1 parent b52fa94 commit 37f4d82

File tree

4 files changed

+350
-0
lines changed

4 files changed

+350
-0
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,6 +1819,52 @@ def TileReductionUsingForallOp :
18191819

18201820
}
18211821

1822+
//===----------------------------------------------------------------------===//
1823+
// ContinuousTileSizesOp
1824+
//===----------------------------------------------------------------------===//
1825+
1826+
def ContinuousTileSizesOp : Op<Transform_Dialect, "structured.continuous_tile_sizes",
1827+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1828+
DeclareOpInterfaceMethods<TransformOpInterface>,
1829+
ReportTrackingListenerFailuresOpTrait]> {
1830+
let description = [{
1831+
This transform takes a linalg as target and a dimension and target size
1832+
as attributes to generate a list of (1) exponentially diminishing
1833+
tile sizes that are powers of 2; and (2) the corresponding chunk-sizes
1834+
the linalg op should be split into along the given dimension.
1835+
1836+
For example, for `target_size` 9, and `dimension` 0 for the following
1837+
linalg op as target
1838+
1839+
```
1840+
%0 = linalg.matmul ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
1841+
outs(%arg2: tensor<25x25xf32>)
1842+
```
1843+
1844+
the first result `tile_sizes` will be a list of diminishing tile sizes
1845+
9, 4, 2, 1; and the second result will be a list of chunk sizes
1846+
18, 4, 2, 1 that the corresponding dimension should be split into.
1847+
1848+
After the linalg has been split along the given dimension (for example using
1849+
multiway split), each chunk can be tiled with the corresponding tile size in
1850+
the `tile_sizes` list generated as a result of this op.
1851+
1852+
Specifying the output type as !transform.param<i64> will cause `tile_sizes`
1853+
and `split_points` to be computed statically and not dynamically.
1854+
}];
1855+
1856+
let arguments = (ins TransformHandleTypeInterface:$target,
1857+
ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension,
1858+
ConfinedAttr<I64Attr, [IntNonNegative]>:$target_size);
1859+
let results = (outs TransformAnyParamTypeOrAnyHandle:$tile_sizes,
1860+
TransformAnyParamTypeOrAnyHandle:$split_points);
1861+
let hasVerifier = 1;
1862+
let assemblyFormat =
1863+
"$target attr-dict `:` custom<ContinuousTileSizeTypes>("
1864+
"type($target), type($tile_sizes), type($split_points))";
1865+
1866+
}
1867+
18221868
//===----------------------------------------------------------------------===//
18231869
// TileUsingForOp
18241870
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,15 @@ struct MultiSizeSpecificationBase {
801801
/// Number of tiles associated with each size.
802802
T lowTripCount, highTripCount;
803803
};
804+
805+
template <typename T>
806+
struct ContinuousTileSizeSpecificationBase {
807+
/// Tile sizes.
808+
SmallVector<T> tileSizes;
809+
/// Number of tiles associated with each size.
810+
SmallVector<T> tripCounts;
811+
};
812+
804813
} // namespace detail
805814

806815
/// A description of a multi-size tiling comprising tile sizes and numbers of
@@ -811,6 +820,11 @@ struct MultiSizeSpecification
811820
struct StaticMultiSizeSpecification
812821
: public detail::MultiSizeSpecificationBase<int64_t> {};
813822

823+
struct ContinuousTileSizeSpecification
824+
: public detail::ContinuousTileSizeSpecificationBase<Value> {};
825+
struct StaticContinuousTileSizeSpecification
826+
: public detail::ContinuousTileSizeSpecificationBase<int64_t> {};
827+
814828
/// Emits the IR computing the multi-sized tiling specification with two tile
815829
/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
816830
/// that there exist numbers of tiles with these sizes that fully cover the
@@ -846,6 +860,12 @@ FailureOr<StaticMultiSizeSpecification>
846860
computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
847861
int64_t divisor);
848862

863+
FailureOr<StaticContinuousTileSizeSpecification>
864+
computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
865+
unsigned targetSize);
866+
FailureOr<ContinuousTileSizeSpecification>
867+
computeContinuousTileSizes(OpBuilder &builder, LinalgOp op, unsigned dimension,
868+
OpFoldResult targetSize, bool emitAssertions);
849869
/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
850870
/// tiling by `numThreads`.
851871
/// If non-empty, the `mapping` is added as an attribute to the

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2581,6 +2581,157 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
25812581
return DiagnosedSilenceableFailure::success();
25822582
}
25832583

2584+
//===----------------------------------------------------------------------===//
2585+
// ContinuousTileSizesOp
2586+
//===----------------------------------------------------------------------===//
2587+
2588+
DiagnosedSilenceableFailure
2589+
transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
2590+
TransformResults &transformResults,
2591+
TransformState &state) {
2592+
2593+
SmallVector<Operation *> targetOps =
2594+
llvm::to_vector(state.getPayloadOps(getTarget()));
2595+
2596+
if (!llvm::hasSingleElement(targetOps)) {
2597+
return emitDefiniteFailure() << "requires exactly one target (got "
2598+
<< llvm::range_size(targetOps) << ")";
2599+
}
2600+
2601+
auto target = dyn_cast<LinalgOp>(*targetOps.begin());
2602+
2603+
OpBuilder builder(target.getContext());
2604+
2605+
if (!target)
2606+
return emitDefiniteFailure() << "expected Linalg Op";
2607+
2608+
if (isa<TransformParamTypeInterface>(getSplitPoints().getType())) {
2609+
if (target.hasDynamicShape()) {
2610+
auto diag = emitSilenceableError()
2611+
<< "cannot compute parametric tile sizes for dynamically "
2612+
"shaped payload op";
2613+
diag.attachNote(target->getLoc()) << "payload op";
2614+
return diag;
2615+
}
2616+
2617+
FailureOr<StaticContinuousTileSizeSpecification> spec =
2618+
computeStaticContinuousTileSizes(target, getDimension(),
2619+
getTargetSize());
2620+
if (failed(spec)) {
2621+
return emitSilenceableError()
2622+
<< "failed to compute multi-size tiling sizes";
2623+
}
2624+
2625+
SmallVector<int64_t> splitPoints;
2626+
2627+
auto tileSizeTripCountPairs =
2628+
llvm::zip_equal(spec->tileSizes, spec->tripCounts);
2629+
2630+
for (auto [idx, pair] : llvm::enumerate(tileSizeTripCountPairs))
2631+
splitPoints.push_back(std::get<0>(pair) * std::get<1>(pair));
2632+
2633+
auto makeI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
2634+
return llvm::to_vector(
2635+
llvm::map_range(values, [&](int64_t value) -> Attribute {
2636+
return builder.getI64IntegerAttr(value);
2637+
}));
2638+
};
2639+
transformResults.setParams(cast<OpResult>(getTileSizes()),
2640+
makeI64AttrsFromI64(spec->tileSizes));
2641+
transformResults.setParams(cast<OpResult>(getSplitPoints()),
2642+
makeI64AttrsFromI64(splitPoints));
2643+
2644+
return DiagnosedSilenceableFailure::success();
2645+
}
2646+
2647+
builder.setInsertionPoint(target);
2648+
2649+
OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
2650+
unsigned dimension = getDimension();
2651+
2652+
FailureOr<ContinuousTileSizeSpecification> spec =
2653+
computeContinuousTileSizes(builder, target, dimension, targetSize, true);
2654+
if (failed(spec)) {
2655+
return emitSilenceableError() << "could not generate tile size computation";
2656+
}
2657+
2658+
auto tileSizeTripCountPairs =
2659+
llvm::zip_equal(spec->tileSizes, spec->tripCounts);
2660+
2661+
AffineExpr s0 = builder.getAffineSymbolExpr(0);
2662+
AffineExpr s1 = builder.getAffineSymbolExpr(1);
2663+
auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
2664+
return affine::makeComposedAffineApply(builder, target->getLoc(), expr,
2665+
ofrs);
2666+
};
2667+
2668+
SmallVector<Value> splitPoints;
2669+
Value splitPoint;
2670+
for (auto [idx, pair] : llvm::enumerate(tileSizeTripCountPairs)) {
2671+
splitPoint = apply(s0 * s1, {std::get<0>(pair), std::get<1>(pair)});
2672+
splitPoints.push_back(splitPoint);
2673+
}
2674+
2675+
auto makeOpFromValue = [&](ArrayRef<Value> values) {
2676+
return llvm::to_vector(
2677+
llvm::map_range(values, [&](Value value) -> Operation * {
2678+
return value.getDefiningOp();
2679+
}));
2680+
};
2681+
2682+
transformResults.set(cast<OpResult>(getTileSizes()),
2683+
makeOpFromValue(spec->tileSizes));
2684+
transformResults.set(cast<OpResult>(getSplitPoints()),
2685+
makeOpFromValue(splitPoints));
2686+
2687+
return DiagnosedSilenceableFailure::success();
2688+
}
2689+
2690+
LogicalResult transform::ContinuousTileSizesOp::verify() {
2691+
2692+
if (getTileSizes().getType() != getSplitPoints().getType()) {
2693+
return emitOpError() << "expects all results type to be the same";
2694+
}
2695+
2696+
return success();
2697+
}
2698+
2699+
void transform::ContinuousTileSizesOp::getEffects(
2700+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2701+
if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
2702+
onlyReadsPayload(effects);
2703+
else
2704+
modifiesPayload(effects);
2705+
onlyReadsHandle(getTarget(), effects);
2706+
producesHandle(getTileSizes(), effects);
2707+
producesHandle(getSplitPoints(), effects);
2708+
}
2709+
2710+
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
2711+
Type targetType, Type tile_sizes,
2712+
Type) {
2713+
printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
2714+
}
2715+
2716+
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
2717+
Type &targetType,
2718+
Type &tileSizesType,
2719+
Type &splitPointsType) {
2720+
FunctionType funcType;
2721+
llvm::SMLoc typeLoc = parser.getCurrentLocation();
2722+
if (failed(parser.parseType<FunctionType>(funcType)))
2723+
return failure();
2724+
2725+
if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
2726+
parser.emitError(typeLoc) << "expects a trailing functional type with one "
2727+
"argument and one result";
2728+
}
2729+
targetType = funcType.getInput(0);
2730+
tileSizesType = splitPointsType = funcType.getResult(0);
2731+
2732+
return success();
2733+
}
2734+
25842735
//===----------------------------------------------------------------------===//
25852736
// TileUsingForOp
25862737
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,139 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
107107
b.getStringAttr("expected strictly positive tile size and divisor"));
108108
}
109109

110+
FailureOr<StaticContinuousTileSizeSpecification>
111+
mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
112+
unsigned targetSize) {
113+
114+
assert(!op.hasDynamicShape() &&
115+
"cannot compute static multi-tile sizes for an op with dynamic shape");
116+
assert(targetSize > 0 && "target size must be non-negative");
117+
assert(dimension < op.getNumLoops() && "dimension overflow");
118+
119+
StaticContinuousTileSizeSpecification spec;
120+
int64_t loopRange = op.getStaticLoopRanges()[dimension];
121+
int64_t tripCount = loopRange / targetSize;
122+
123+
unsigned tileSize = targetSize;
124+
125+
spec.tileSizes.push_back(tileSize);
126+
spec.tripCounts.push_back(tripCount);
127+
128+
int64_t remainderChunk = loopRange % targetSize;
129+
130+
while (tileSize > 1 && remainderChunk != 0) {
131+
132+
uint64_t maxPower = llvm::bit_floor(tileSize);
133+
tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
134+
135+
tripCount = remainderChunk / tileSize;
136+
137+
if (tripCount > 0) {
138+
spec.tileSizes.push_back(tileSize);
139+
spec.tripCounts.push_back(tripCount);
140+
}
141+
142+
remainderChunk = remainderChunk % tileSize;
143+
}
144+
145+
auto tripCountCheck = [&](SmallVector<int64_t> tileSizes,
146+
SmallVector<int64_t> tripCounts,
147+
int64_t range) -> bool {
148+
int64_t computedRange = 0;
149+
for (auto [tileSize, tripCount] : llvm::zip(tileSizes, tripCounts))
150+
computedRange += tileSize * tripCount;
151+
return range == computedRange;
152+
};
153+
154+
if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
155+
return failure();
156+
157+
return spec;
158+
}
159+
160+
FailureOr<ContinuousTileSizeSpecification>
161+
mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, LinalgOp op,
162+
unsigned dimension,
163+
OpFoldResult targetSize,
164+
bool emitAssertions) {
165+
166+
// Bail out on dimension overflow.
167+
if (dimension >= op.getNumLoops())
168+
return failure();
169+
170+
// The code below works only on values.
171+
Location loc = op.getLoc();
172+
ImplicitLocOpBuilder b(loc, builder);
173+
if (emitAssertions) {
174+
emitIsPositiveIndexAssertion(b, targetSize);
175+
}
176+
Value targetSizeValue =
177+
getValueOrCreateConstantIndexOp(builder, loc, targetSize);
178+
179+
// Find the trip count of the iteration space dimension for which the tile
180+
// sizes are computed.
181+
SmallVector<OpFoldResult> allShapes =
182+
op.createFlatListOfOperandDims(b, b.getLoc());
183+
AffineMap shapesToLoops = op.getShapesToLoopsMap();
184+
SmallVector<OpFoldResult> loopRanges =
185+
makeComposedFoldedMultiResultAffineApply(b, op.getLoc(), shapesToLoops,
186+
allShapes);
187+
188+
Value loopRange =
189+
getValueOrCreateConstantIndexOp(b, op.getLoc(), loopRanges[dimension]);
190+
191+
ContinuousTileSizeSpecification spec;
192+
193+
// Compute the tile sizes and the respective numbers of tiles.
194+
AffineExpr s0 = b.getAffineSymbolExpr(0);
195+
AffineExpr s1 = b.getAffineSymbolExpr(1);
196+
auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
197+
return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
198+
};
199+
200+
Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
201+
Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});
202+
203+
OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
204+
b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});
205+
206+
uint64_t tileSizeInt = *getConstantIntValue(targetSizeValue);
207+
208+
assert(tileSizeInt > 0 && "target size must be non-negative");
209+
210+
spec.tileSizes.push_back(targetSizeValue);
211+
spec.tripCounts.push_back(tripCountValue);
212+
213+
while (tileSizeInt > 1) {
214+
uint64_t maxPower = llvm::bit_floor(tileSizeInt);
215+
tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
216+
auto constStepOp =
217+
builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
218+
tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});
219+
220+
tripCountSize = affine::makeComposedFoldedAffineApply(
221+
b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});
222+
223+
// Optimization if tripCount can be determined to be zero.
224+
if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tripCountSize)) {
225+
auto intAttr = cast<IntegerAttr>(attr);
226+
bool isTripCountZero = intAttr.getValue().isZero();
227+
228+
if (!isTripCountZero) {
229+
spec.tileSizes.push_back(constStepOp);
230+
spec.tripCounts.push_back(tripCountValue);
231+
}
232+
} else {
233+
spec.tileSizes.push_back(constStepOp);
234+
spec.tripCounts.push_back(tripCountValue);
235+
}
236+
237+
remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
238+
}
239+
240+
return spec;
241+
}
242+
110243
FailureOr<StaticMultiSizeSpecification>
111244
mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
112245
int64_t targetSize, int64_t divisor) {

0 commit comments

Comments
 (0)