Skip to content

Commit 7bc956d

Browse files
[mlir][PartialReductionTilingInterface] Add support for ReductionTilingStrategy::PartialReductionOuterParallel in tileUsingSCF. (#143988)
Following up from #143467, this PR adds support for `ReductionTilingStrategy::PartialReductionOuterParallel` to `tileUsingSCF`. The implementation of `PartialReductionTilingInterface` for `Linalg` ops has been updated to support this strategy as well. This brings `tileUsingSCF` on par with `linalg::tileReductionUsingForall`, which will be deprecated subsequently. Changes summary - `PartialReductionTilingInterface` changes : - The `tileToPartialReduction` method needed to get the induction variables of the generated tile loops. This was needed to keep the generated code similar to `linalg::tileReductionUsingForall`, specifically to create a simplified access for slicing the intermediate partial results tensor when tiled in `num_threads` mode. - The `getPartialResultTilePosition` method needs the induction variables for the generated tile loops for the same reason as above, and also needs the `tilingStrategy` to be passed in to generate correct code. The tests in `transform-tile-reduction.mlir` testing the `linalg::tileReductionUsingForall` have been moved over to test `scf::tileUsingSCF` with `ReductionTilingStrategy::PartialReductionOuterParallel` strategy. Some of the tests that were doing further cyclic distribution of the transformed code from tiling are removed. Those seem like two separate transformations that were merged into one. Ideally that would need to happen when resolving the `scf.forall` rather than during tiling. Please review only the top commit. Depends on #143467 Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 6c232f4 commit 7bc956d

File tree

9 files changed

+535
-265
lines changed

9 files changed

+535
-265
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2019,6 +2019,7 @@ def TileReductionUsingForallOp :
20192019

20202020
// TODO: support mixed static-dynamic (see TileUsingForallOp).
20212021
let arguments = (ins TransformHandleTypeInterface:$target,
2022+
DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
20222023
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
20232024
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
20242025
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
@@ -2036,10 +2037,11 @@ def TileReductionUsingForallOp :
20362037

20372038
let assemblyFormat = [{
20382039
$target
2040+
(`reduction_dims` `=` $reduction_dims^)?
20392041
`by`
20402042
(`num_threads` `=` $num_threads^)?
2041-
(`,` `tile_sizes` `=` $tile_sizes^)?
2042-
(`,` `mapping` `=` $mapping^)?
2043+
(`tile_sizes` `=` $tile_sizes^)?
2044+
(`mapping` `=` $mapping^)?
20432045
attr-dict
20442046
`:` functional-type(operands, results)
20452047
}];

mlir/include/mlir/Dialect/Utils/StaticValueUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
156156
/// corresponding pair of arrays. This is the inverse function of
157157
/// `getMixedValues`.
158158
std::pair<SmallVector<int64_t>, SmallVector<Value>>
159-
decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues);
159+
decomposeMixedValues(ArrayRef<OpFoldResult> mixedValues);
160160

161161
/// Helper to sort `values` according to matching `keys`.
162162
SmallVector<Value>

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -367,23 +367,28 @@ def PartialReductionOpInterface :
367367
OpInterface<"PartialReductionOpInterface", [TilingInterface]> {
368368
let description = [{
369369
Interface for allowing operations to expose information needed to
370-
tile reductions using partial reduction followed by merge. This is
371-
complementary to TilingInterface to tile reductions.
370+
tile reductions using partial reduction followed by merge. This
371+
extends the `TilingInterface` to allow splitting a reduction
372+
dimension into a parallel dimension and reduction dimension.
373+
The materialized inter-tile loop could either be the reduction dimension
374+
(i.e. `ReductionTilingStrategy::PartialReductionOuterReduction`) or
375+
the parallel dimension (i.e.
376+
`ReductionTilingStrategy::PartialReductionOuterParallel`).
372377
}];
373378
let cppNamespace = "::mlir";
374379
let methods = [
375380
InterfaceMethod<
376381
/*desc=*/[{
377382
Method to generate a tensor initialized with the identity value of the
378-
operation reduction. The tensor shape is equal to operation result
383+
reduction operator. The tensor shape is equal to operation result
379384
shape with new dimension for each non zero tile size.
380385
}],
381386
/*retType=*/"::mlir::FailureOr<SmallVector<Value>>",
382387
/*methodName=*/"generateInitialTensorForPartialReduction",
383388
/*args=*/(ins
384389
"::mlir::OpBuilder &":$b,
385390
"Location":$loc,
386-
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
391+
"::mlir::ArrayRef<::mlir::OpFoldResult>":$tileSizes,
387392
"const ::mlir::SetVector<unsigned> &":$reductionDims),
388393
/*methodBody=*/"",
389394
/*defaultImplementation=*/[{
@@ -396,6 +401,11 @@ def PartialReductionOpInterface :
396401
reduction dimension are converted to parallel dimensions with a size
397402
less or equal to the tile size. This is meant to be used with
398403
`mergeReductions` method which will combine the partial reductions.
404+
The method receives the `offsets` and `sizes` for all iteration space
405+
dimensions, as well as the iteration number of the tiled reduction
406+
dimensions (which is the induction variable of the inter-tile loop
407+
for the reduction dimension divided by the step of the loop) in
408+
`splitReductionIvs`.
399409
}],
400410
/*retType=*/"::mlir::FailureOr<TilingResult>",
401411
/*methodName=*/"tileToPartialReduction",
@@ -406,7 +416,8 @@ def PartialReductionOpInterface :
406416
"ValueRange":$init,
407417
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
408418
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
409-
"const ::llvm::SetVector<unsigned> &":$reductionDims),
419+
"const ::llvm::SetVector<unsigned> &":$reductionDims,
420+
"::mlir::ArrayRef<::mlir::OpFoldResult>":$splitReductionIvs),
410421
/*methodBody=*/"",
411422
/*defaultImplementation=*/[{
412423
return failure();
@@ -436,15 +447,22 @@ def PartialReductionOpInterface :
436447
the tiled operation. This is same as
437448
TilingInterface:::getResultTilePosition, but determines the result
438449
tile position for partial reduction.
450+
The method receives the `offsets` and `sizes` for all iteration space
451+
dimensions, as well as the iteration number of the tiled reduction
452+
dimensions (which is the induction variable of the inter-tile loop
453+
for the reduction dimension divided by the tile size specified) in
454+
`splitReductionIvs`.
439455
}],
440456
/*retType=*/"::llvm::LogicalResult",
441457
/*methodName=*/"getPartialResultTilePosition",
442458
/*args=*/(ins
443459
"::mlir::OpBuilder &":$b,
444460
"unsigned":$resultNumber,
461+
"ReductionTilingStrategy":$tilingStrategy,
445462
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
446463
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
447464
"const ::mlir::SetVector<unsigned> &":$reductionDims,
465+
"::mlir::ArrayRef<::mlir::OpFoldResult>":$splitReductionIvs,
448466
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
449467
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
450468
/*methodBody=*/"",

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3022,6 +3022,7 @@ void transform::TileReductionUsingForallOp::build(
30223022
build(builder, result,
30233023
/*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
30243024
/*target=*/target,
3025+
/*reduction_dims=*/{},
30253026
/*num_threads=*/staticNumThreadsAttr,
30263027
/*tile_sizes=*/staticTileSizesAttr,
30273028
/*mapping=*/mapping);
@@ -3036,23 +3037,45 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
30363037
getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
30373038
SmallVector<OpFoldResult> tileSizes =
30383039
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
3039-
FailureOr<linalg::ForallReductionTilingResult> result =
3040-
linalg::tileReductionUsingForall(
3041-
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
3042-
numThreads, tileSizes, getMapping());
3040+
3041+
scf::SCFTilingOptions options;
3042+
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
3043+
options.setReductionTilingStrategy(
3044+
ReductionTilingStrategy::PartialReductionOuterParallel);
3045+
if (!getNumThreads().empty()) {
3046+
options.setNumThreads(numThreads);
3047+
} else {
3048+
options.setTileSizes(tileSizes);
3049+
}
3050+
if (auto mapping = getMapping()) {
3051+
options.setMapping(mapping.value().getValue());
3052+
}
3053+
SmallVector<unsigned> reductionDims =
3054+
extractFromIntegerArrayAttr<unsigned>(getReductionDims());
3055+
if (reductionDims.empty()) {
3056+
for (auto [idx, iteratorType] :
3057+
llvm::enumerate(target.getIteratorTypesArray())) {
3058+
if (iteratorType == utils::IteratorType::reduction)
3059+
reductionDims.push_back(idx);
3060+
}
3061+
}
3062+
options.setReductionDims(reductionDims);
3063+
FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
3064+
rewriter, cast<TilingInterface>(target.getOperation()), options);
30433065

30443066
if (failed(result)) {
30453067
auto diag = emitSilenceableError() << "could not tile reduction";
3046-
diag.attachNote(target.getLoc()) << "target operation";
30473068
return diag;
30483069
}
3070+
rewriter.replaceOp(target, result->replacements);
3071+
30493072
for (Value initValue : result->initialValues)
30503073
results.push_back(initValue.getDefiningOp());
3051-
for (auto parallelTiledOp : result->parallelTiledOps)
3074+
for (auto parallelTiledOp : result->tiledOps)
30523075
results.push_back(parallelTiledOp);
30533076
for (auto mergeOp : result->mergeOps)
30543077
results.push_back(mergeOp);
3055-
results.push_back(result->loops);
3078+
results.push_back(result->loops.front());
30563079
return DiagnosedSilenceableFailure::success();
30573080
}
30583081

0 commit comments

Comments
 (0)