Skip to content

Commit a6155b6

Browse files
committed
[MLIR] Add continuous tiling to Transform dialect
Add continuous tiling op `structured.continuous_tile_sizes` to the transform dialect that returns as result (1) a list of exponentially diminishing tile sizes, and (2) a list of chunk sizes -- along the specified dimension of the target -- where the corresponding tile sizes from (1) can be applied. The list of chunk sizes from (2) cover the entire iteration space along the given dimension of the target.
1 parent a65771f commit a6155b6

File tree

4 files changed

+345
-0
lines changed

4 files changed

+345
-0
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,6 +1819,51 @@ def TileReductionUsingForallOp :
18191819

18201820
}
18211821

1822+
//===----------------------------------------------------------------------===//
1823+
// ContinuousTileSizesOp
1824+
//===----------------------------------------------------------------------===//
1825+
1826+
def ContinuousTileSizesOp : Op<Transform_Dialect, "structured.continuous_tile_sizes",
1827+
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
1828+
DeclareOpInterfaceMethods<TransformOpInterface>,
1829+
ReportTrackingListenerFailuresOpTrait]> {
1830+
let description = [{
1831+
This transform emits the IR computing the list of (1) exponentially
1832+
diminishing tile sizes that are powers of 2; and (2) the corresponding
1833+
chunk-sizes the target op should be split into along the given dimension.
1834+
1835+
For example, for `target_size` 9, and `dimension` 0 for the following
1836+
linalg op as target
1837+
1838+
```
1839+
%0 = linalg.matmul ins(%arg0, %arg1: tensor<25x34xf32>, tensor<34x25xf32>)
1840+
outs(%arg2: tensor<25x25xf32>)
1841+
```
1842+
1843+
the first result `tile_sizes` will be a list of diminishing tile sizes
1844+
9, 4, 2, 1; and the second result will be a list of chunk sizes
1845+
18, 4, 2, 1 that the corresponding dimension should be split into.
1846+
1847+
After the target op has been split along the given dimension (for example
1848+
using multiway split), each chunk can be tiled with the corresponding tile
1849+
size in the `tile_sizes` list generated as a result of this op.
1850+
1851+
Specifying the output type as !transform.param<i64> will cause `tile_sizes`
1852+
and `chunk_sizes` to be computed statically and not dynamically.
1853+
}];
1854+
1855+
let arguments = (ins TransformHandleTypeInterface:$target,
1856+
ConfinedAttr<I64Attr, [IntNonNegative]>:$dimension,
1857+
ConfinedAttr<I64Attr, [IntNonNegative]>:$target_size);
1858+
let results = (outs TransformAnyParamTypeOrAnyHandle:$tile_sizes,
1859+
TransformAnyParamTypeOrAnyHandle:$chunk_sizes);
1860+
let hasVerifier = 1;
1861+
let assemblyFormat =
1862+
"$target attr-dict `:` custom<ContinuousTileSizeTypes>("
1863+
"type($target), type($tile_sizes), type($chunk_sizes))";
1864+
1865+
}
1866+
18221867
//===----------------------------------------------------------------------===//
18231868
// TileUsingForOp
18241869
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,15 @@ struct MultiSizeSpecificationBase {
801801
/// Number of tiles associated with each size.
802802
T lowTripCount, highTripCount;
803803
};
804+
805+
template <typename T>
806+
struct ContinuousTileSizeSpecificationBase {
807+
/// Tile sizes.
808+
SmallVector<T> tileSizes;
809+
/// Number of tiles associated with each size.
810+
SmallVector<T> tripCounts;
811+
};
812+
804813
} // namespace detail
805814

806815
/// A description of a multi-size tiling comprising tile sizes and numbers of
@@ -811,6 +820,11 @@ struct MultiSizeSpecification
811820
struct StaticMultiSizeSpecification
812821
: public detail::MultiSizeSpecificationBase<int64_t> {};
813822

823+
struct ContinuousTileSizeSpecification
824+
: public detail::ContinuousTileSizeSpecificationBase<Value> {};
825+
struct StaticContinuousTileSizeSpecification
826+
: public detail::ContinuousTileSizeSpecificationBase<int64_t> {};
827+
814828
/// Emits the IR computing the multi-sized tiling specification with two tile
815829
/// sizes not exceeding `targetSize`, each divisible by `sizeDivisor`, such
816830
/// that there exist numbers of tiles with these sizes that fully cover the
@@ -846,6 +860,13 @@ FailureOr<StaticMultiSizeSpecification>
846860
computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
847861
int64_t divisor);
848862

863+
FailureOr<StaticContinuousTileSizeSpecification>
864+
computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
865+
unsigned targetSize);
866+
FailureOr<ContinuousTileSizeSpecification>
867+
computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
868+
unsigned dimension, OpFoldResult targetSize,
869+
bool emitAssertions);
849870
/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
850871
/// tiling by `numThreads`.
851872
/// If non-empty, the `mapping` is added as an attribute to the

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2583,6 +2583,154 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
25832583
return DiagnosedSilenceableFailure::success();
25842584
}
25852585

2586+
//===----------------------------------------------------------------------===//
2587+
// ContinuousTileSizesOp
2588+
//===----------------------------------------------------------------------===//
2589+
2590+
DiagnosedSilenceableFailure
2591+
transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
2592+
TransformResults &transformResults,
2593+
TransformState &state) {
2594+
2595+
SmallVector<Operation *> targetOps =
2596+
llvm::to_vector(state.getPayloadOps(getTarget()));
2597+
2598+
if (!llvm::hasSingleElement(targetOps)) {
2599+
return mlir::emitSilenceableFailure(getLoc())
2600+
<< "requires exactly one target (got " << llvm::range_size(targetOps)
2601+
<< ")";
2602+
}
2603+
2604+
Operation *target = *targetOps.begin();
2605+
auto linalgOp = dyn_cast<LinalgOp>(target);
2606+
auto tileableOp = dyn_cast<TilingInterface>(target);
2607+
2608+
if (!linalgOp)
2609+
return emitDefiniteFailure() << "expected Linalg Op";
2610+
2611+
OpBuilder builder(linalgOp.getContext());
2612+
2613+
if (isa<TransformParamTypeInterface>(getChunkSizes().getType())) {
2614+
if (linalgOp.hasDynamicShape()) {
2615+
auto diag = emitSilenceableError()
2616+
<< "cannot compute parametric tile sizes for dynamically "
2617+
"shaped payload op";
2618+
diag.attachNote(linalgOp->getLoc()) << "payload op";
2619+
return diag;
2620+
}
2621+
2622+
FailureOr<StaticContinuousTileSizeSpecification> spec =
2623+
computeStaticContinuousTileSizes(linalgOp, getDimension(),
2624+
getTargetSize());
2625+
if (failed(spec)) {
2626+
return emitSilenceableError()
2627+
<< "failed to compute multi-size tiling sizes";
2628+
}
2629+
2630+
SmallVector<int64_t> chunkSizes;
2631+
2632+
for (auto &&[tileSize, tripCount] :
2633+
llvm::zip_equal(spec->tileSizes, spec->tripCounts))
2634+
chunkSizes.push_back(tileSize * tripCount);
2635+
2636+
auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
2637+
return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
2638+
return builder.getI64IntegerAttr(value);
2639+
});
2640+
};
2641+
transformResults.setParams(cast<OpResult>(getTileSizes()),
2642+
getI64AttrsFromI64(spec->tileSizes));
2643+
transformResults.setParams(cast<OpResult>(getChunkSizes()),
2644+
getI64AttrsFromI64(chunkSizes));
2645+
2646+
return DiagnosedSilenceableFailure::success();
2647+
}
2648+
2649+
builder.setInsertionPoint(linalgOp);
2650+
2651+
OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
2652+
unsigned dimension = getDimension();
2653+
2654+
FailureOr<ContinuousTileSizeSpecification> spec = computeContinuousTileSizes(
2655+
builder, tileableOp, dimension, targetSize, true);
2656+
if (failed(spec)) {
2657+
return emitSilenceableError() << "could not generate tile size computation";
2658+
}
2659+
2660+
AffineExpr s0 = builder.getAffineSymbolExpr(0);
2661+
AffineExpr s1 = builder.getAffineSymbolExpr(1);
2662+
auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
2663+
return affine::makeComposedAffineApply(builder, linalgOp->getLoc(), expr,
2664+
ofrs);
2665+
};
2666+
2667+
SmallVector<Value> chunkSizes;
2668+
Value splitPoint;
2669+
for (auto &&[tileSize, tripCount] :
2670+
llvm::zip_equal(spec->tileSizes, spec->tripCounts)) {
2671+
splitPoint = apply(s0 * s1, {tileSize, tripCount});
2672+
chunkSizes.push_back(splitPoint);
2673+
}
2674+
2675+
auto getDefiningOps = [&](ArrayRef<Value> values) {
2676+
return llvm::map_to_vector(values, [&](Value value) -> Operation * {
2677+
return value.getDefiningOp();
2678+
});
2679+
};
2680+
2681+
transformResults.set(cast<OpResult>(getTileSizes()),
2682+
getDefiningOps(spec->tileSizes));
2683+
transformResults.set(cast<OpResult>(getChunkSizes()),
2684+
getDefiningOps(chunkSizes));
2685+
2686+
return DiagnosedSilenceableFailure::success();
2687+
}
2688+
2689+
LogicalResult transform::ContinuousTileSizesOp::verify() {
2690+
2691+
if (getTileSizes().getType() != getChunkSizes().getType()) {
2692+
return emitOpError() << "expects all results type to be the same";
2693+
}
2694+
2695+
return success();
2696+
}
2697+
2698+
void transform::ContinuousTileSizesOp::getEffects(
2699+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2700+
if (isa<TransformParamTypeInterface>(getTileSizes().getType()))
2701+
onlyReadsPayload(effects);
2702+
else
2703+
modifiesPayload(effects);
2704+
onlyReadsHandle(getTarget(), effects);
2705+
producesHandle(getTileSizes(), effects);
2706+
producesHandle(getChunkSizes(), effects);
2707+
}
2708+
2709+
static void printContinuousTileSizeTypes(OpAsmPrinter &printer, Operation *op,
2710+
Type targetType, Type tile_sizes,
2711+
Type) {
2712+
printer.printFunctionalType(TypeRange{targetType}, TypeRange{tile_sizes});
2713+
}
2714+
2715+
static ParseResult parseContinuousTileSizeTypes(OpAsmParser &parser,
2716+
Type &targetType,
2717+
Type &tileSizesType,
2718+
Type &chunkSizesType) {
2719+
FunctionType funcType;
2720+
llvm::SMLoc typeLoc = parser.getCurrentLocation();
2721+
if (failed(parser.parseType<FunctionType>(funcType)))
2722+
return failure();
2723+
2724+
if (funcType.getNumInputs() != 1 || funcType.getNumResults() != 1) {
2725+
parser.emitError(typeLoc) << "expects a trailing functional type with one "
2726+
"argument and one result";
2727+
}
2728+
targetType = funcType.getInput(0);
2729+
tileSizesType = chunkSizesType = funcType.getResult(0);
2730+
2731+
return success();
2732+
}
2733+
25862734
//===----------------------------------------------------------------------===//
25872735
// TileUsingForOp
25882736
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,137 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
107107
b.getStringAttr("expected strictly positive tile size and divisor"));
108108
}
109109

110+
FailureOr<StaticContinuousTileSizeSpecification>
111+
mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
112+
unsigned dimension,
113+
unsigned targetSize) {
114+
115+
assert(!op.hasDynamicShape() &&
116+
"cannot compute static multi-tile sizes for an op with dynamic shape");
117+
assert(targetSize > 0 && "target size must be non-negative");
118+
assert(dimension < op.getNumLoops() && "dimension overflow");
119+
120+
StaticContinuousTileSizeSpecification spec;
121+
int64_t loopRange = op.getStaticLoopRanges()[dimension];
122+
int64_t tripCount = loopRange / targetSize;
123+
124+
unsigned tileSize = targetSize;
125+
126+
spec.tileSizes.push_back(tileSize);
127+
spec.tripCounts.push_back(tripCount);
128+
129+
int64_t remainderChunk = loopRange % targetSize;
130+
131+
while (tileSize > 1 && remainderChunk != 0) {
132+
133+
uint64_t maxPower = llvm::bit_floor(tileSize);
134+
tileSize = maxPower == tileSize ? maxPower >> 1 : maxPower;
135+
136+
tripCount = remainderChunk / tileSize;
137+
138+
if (tripCount > 0) {
139+
spec.tileSizes.push_back(tileSize);
140+
spec.tripCounts.push_back(tripCount);
141+
}
142+
143+
remainderChunk = remainderChunk % tileSize;
144+
}
145+
146+
auto tripCountCheck = [&](SmallVector<int64_t> tileSizes,
147+
SmallVector<int64_t> tripCounts,
148+
int64_t range) -> bool {
149+
int64_t computedRange = 0;
150+
for (auto [tileSize, tripCount] : llvm::zip(tileSizes, tripCounts))
151+
computedRange += tileSize * tripCount;
152+
return range == computedRange;
153+
};
154+
155+
if (!tripCountCheck(spec.tileSizes, spec.tripCounts, loopRange))
156+
return failure();
157+
158+
return spec;
159+
}
160+
161+
FailureOr<ContinuousTileSizeSpecification>
162+
mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
163+
unsigned dimension,
164+
OpFoldResult targetSize,
165+
bool emitAssertions) {
166+
167+
SmallVector<Range> loopRanges = op.getIterationDomain(builder);
168+
unsigned numLoops = loopRanges.size();
169+
170+
// Bail out on dimension overflow.
171+
if (dimension >= numLoops)
172+
return failure();
173+
174+
// The code below works only on values.
175+
Location loc = op->getLoc();
176+
ImplicitLocOpBuilder b(loc, builder);
177+
if (emitAssertions) {
178+
emitIsPositiveIndexAssertion(b, targetSize);
179+
}
180+
Value targetSizeValue =
181+
getValueOrCreateConstantIndexOp(builder, loc, targetSize);
182+
183+
// Find the trip count of the iteration space dimension for which the tile
184+
// sizes are computed.
185+
Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
186+
loopRanges[dimension].size);
187+
ContinuousTileSizeSpecification spec;
188+
189+
// Compute the tile sizes and the respective numbers of tiles.
190+
AffineExpr s0 = b.getAffineSymbolExpr(0);
191+
AffineExpr s1 = b.getAffineSymbolExpr(1);
192+
auto apply = [&](AffineExpr expr, ArrayRef<OpFoldResult> ofrs) -> Value {
193+
return affine::makeComposedAffineApply(b, b.getLoc(), expr, ofrs);
194+
};
195+
196+
Value tripCountValue = apply(s0.floorDiv(s1), {loopRange, targetSizeValue});
197+
Value remainderChunkValue = apply(s0 % s1, {loopRange, targetSizeValue});
198+
199+
OpFoldResult tripCountSize = affine::makeComposedFoldedAffineApply(
200+
b, b.getLoc(), s0.floorDiv(s1), {loopRange, targetSizeValue});
201+
202+
// emitAssertions above already asserts that targetSize is
203+
// a poistive integer.
204+
uint64_t tileSizeInt = *getConstantIntValue(targetSizeValue);
205+
206+
assert(tileSizeInt > 0 && "target size must be non-negative");
207+
208+
spec.tileSizes.push_back(targetSizeValue);
209+
spec.tripCounts.push_back(tripCountValue);
210+
211+
while (tileSizeInt > 1) {
212+
uint64_t maxPower = llvm::bit_floor(tileSizeInt);
213+
tileSizeInt = maxPower == tileSizeInt ? maxPower >> 1 : maxPower;
214+
auto constStepOp =
215+
builder.createOrFold<arith::ConstantIndexOp>(b.getLoc(), tileSizeInt);
216+
tripCountValue = apply(s0.floorDiv(s1), {remainderChunkValue, constStepOp});
217+
218+
tripCountSize = affine::makeComposedFoldedAffineApply(
219+
b, b.getLoc(), s0.floorDiv(s1), {remainderChunkValue, constStepOp});
220+
221+
// Optimization if tripCount can be determined to be zero.
222+
if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tripCountSize)) {
223+
auto intAttr = cast<IntegerAttr>(attr);
224+
bool isTripCountZero = intAttr.getValue().isZero();
225+
226+
if (!isTripCountZero) {
227+
spec.tileSizes.push_back(constStepOp);
228+
spec.tripCounts.push_back(tripCountValue);
229+
}
230+
} else {
231+
spec.tileSizes.push_back(constStepOp);
232+
spec.tripCounts.push_back(tripCountValue);
233+
}
234+
235+
remainderChunkValue = apply(s0 % s1, {remainderChunkValue, constStepOp});
236+
}
237+
238+
return spec;
239+
}
240+
110241
FailureOr<StaticMultiSizeSpecification>
111242
mlir::linalg::computeStaticMultiTileSizes(LinalgOp op, unsigned dimension,
112243
int64_t targetSize, int64_t divisor) {

0 commit comments

Comments
 (0)