Skip to content

Commit 1be04a6

Browse files
[mlir][PartialReductionTilingInterface] Generalize implementation of tileUsingSCF for ReductionTilingStrategy::PartialOuterReduction.
This is a precursor to generalizing the `tileUsingSCF` to handle `ReductionTilingStrategy::PartialOuterParallel` strategy. This change itself is generalizing/refactoring the current implementation that supports only `ReductionTilingStrategy::PartialOuterReduction`. Changes in this PR - Move the `ReductionTilingStrategy` enum out of `scf::SCFTilingOptions` and make them visible to `TilingInterface`. - `PartialTilingInterface` changes - Pass the `tilingStrategy` used for partial reduction to `tileToPartialReduction`. - Pass the reduction dimension along as `const llvm::SetVector<unsigned> &`. - Allow `scf::SCFTilingOptions` to set the reduction dimensions that are to be tiled. - Change `structured.tiled_reduction_using_for` to allow specification of the reduction dimensions to be partially tiled. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 97cc2b1 commit 1be04a6

File tree

9 files changed

+438
-255
lines changed

9 files changed

+438
-255
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1767,6 +1767,10 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
17671767
- the result-combining op,
17681768
- the parent `for` op.
17691769

1770+
The `reduction_dims` can be used to specify the subset of reduction dimensions
1771+
of the operation to tile. If left unspecified, all reduction dimensions are
1772+
tiled.
1773+
17701774
#### Example:
17711775

17721776
```
@@ -1817,7 +1821,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
18171821

18181822
// TODO: support mixed static-dynamic (see TileUsingForallOp).
18191823
let arguments = (ins TransformHandleTypeInterface:$target,
1820-
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
1824+
DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
1825+
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
18211826
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
18221827
TransformHandleTypeInterface:$split_op,
18231828
TransformHandleTypeInterface:$combining_op,
@@ -1830,6 +1835,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
18301835

18311836
let assemblyFormat = [{
18321837
$target
1838+
(`reduction_dims` `=` $reduction_dims^)?
18331839
`by` `tile_sizes` `=` $tile_sizes
18341840
attr-dict
18351841
`:` functional-type(operands, results)

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -85,28 +85,21 @@ struct SCFTilingOptions {
8585
return *this;
8686
}
8787

88+
/// Specify mapping of loops to devices. This is only respected when the loop
89+
/// constructs support such a mapping (like `scf.forall`). Will be ignored
90+
/// when using loop constructs that dont support such a mapping (like
91+
/// `scf.for`)
92+
SmallVector<Attribute> mappingVector = {};
93+
SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
94+
mappingVector = llvm::to_vector(mapping);
95+
return *this;
96+
}
97+
98+
//-------------------------------------------------------------------------//
99+
// Options related reduction tiling
100+
//-------------------------------------------------------------------------//
101+
88102
/// Specify how reduction dimensions should be tiled.
89-
///
90-
/// Tiling can be thought of as splitting a dimension into 2 and materializing
91-
/// the outer dimension as a loop:
92-
///
93-
/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
94-
///
95-
/// For parallel dimensions, the split can only happen in one way, with both
96-
/// dimensions being parallel. For reduction dimensions however, there is a
97-
/// choice in how we split the reduction dimension. This enum exposes this
98-
/// choice.
99-
enum class ReductionTilingStrategy {
100-
// [reduction] -> [reduction1, reduction2]
101-
// -> loop[reduction1] { [reduction2] }
102-
FullReduction,
103-
// [reduction] -> [reduction1, parallel2]
104-
// -> loop[reduction1] { [parallel2] }; merge[reduction1]
105-
PartialReductionOuterReduction,
106-
// [reduction] -> [parallel1, reduction2]
107-
// -> loop[parallel1] { [reduction2] }; merge[parallel1]
108-
PartialReductionOuterParallel
109-
};
110103
ReductionTilingStrategy reductionStrategy =
111104
ReductionTilingStrategy::FullReduction;
112105
SCFTilingOptions &
@@ -115,13 +108,10 @@ struct SCFTilingOptions {
115108
return *this;
116109
}
117110

118-
/// Specify mapping of loops to devices. This is only respected when the loop
119-
/// constructs support such a mapping (like `scf.forall`). Will be ignored
120-
/// when using loop constructs that dont support such a mapping (like
121-
/// `scf.for`)
122-
SmallVector<Attribute> mappingVector = {};
123-
SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
124-
mappingVector = llvm::to_vector(mapping);
111+
/// Specify the reduction dimensions to be tiled.
112+
SetVector<unsigned> reductionDims;
113+
SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
114+
reductionDims.insert(dims.begin(), dims.end());
125115
return *this;
126116
}
127117
};

mlir/include/mlir/Interfaces/TilingInterface.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,27 @@ struct TilingResult {
3636
SmallVector<Operation *> generatedSlices;
3737
};
3838

39+
/// Tiling can be thought of as splitting a dimension into 2 and
40+
/// materializing the outer dimension as a loop:
41+
///
42+
/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
43+
///
44+
/// For parallel dimensions, the split can only happen in one way, with both
45+
/// dimensions being parallel. For reduction dimensions however, there is a
46+
/// choice in how we split the reduction dimension. This enum exposes this
47+
/// choice.
48+
enum class ReductionTilingStrategy {
49+
// [reduction] -> [reduction1, reduction2]
50+
// -> loop[reduction1] { [reduction2] }
51+
FullReduction,
52+
// [reduction] -> [reduction1, parallel2]
53+
// -> loop[reduction1] { [parallel2] }; merge[reduction1]
54+
PartialReductionOuterReduction,
55+
// [reduction] -> [parallel1, reduction2]
56+
// -> loop[parallel1] { [reduction2] }; merge[parallel1]
57+
PartialReductionOuterParallel
58+
};
59+
3960
/// Container for the result of merge operation of tiling.
4061
/// - `mergeOps` contains operations created during the merge.
4162
/// - `replacements` contains the values that represents the result of the

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def PartialReductionOpInterface :
384384
"::mlir::OpBuilder &":$b,
385385
"Location":$loc,
386386
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
387-
"::mlir::ArrayRef<int>":$reductionDim),
387+
"const ::mlir::SetVector<unsigned> &":$reductionDim),
388388
/*methodBody=*/"",
389389
/*defaultImplementation=*/[{
390390
return failure();
@@ -402,10 +402,11 @@ def PartialReductionOpInterface :
402402
/*args=*/(ins
403403
"::mlir::OpBuilder &":$b,
404404
"Location ":$loc,
405+
"::mlir::ReductionTilingStrategy":$tilingStrategy,
405406
"ValueRange":$init,
406407
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
407408
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
408-
"::mlir::ArrayRef<int>":$reductionDims),
409+
"const ::llvm::SetVector<unsigned> &":$reductionDims),
409410
/*methodBody=*/"",
410411
/*defaultImplementation=*/[{
411412
return failure();
@@ -423,7 +424,7 @@ def PartialReductionOpInterface :
423424
"::mlir::OpBuilder &":$b,
424425
"Location ":$loc,
425426
"ValueRange":$partialReduce,
426-
"::mlir::ArrayRef<int>":$reductionDim),
427+
"const ::mlir::SetVector<unsigned> &":$reductionDims),
427428
/*methodBody=*/"",
428429
/*defaultImplementation=*/[{
429430
return failure();
@@ -443,9 +444,9 @@ def PartialReductionOpInterface :
443444
"unsigned":$resultNumber,
444445
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
445446
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
447+
"const ::mlir::SetVector<unsigned> &":$reductionDims,
446448
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
447-
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
448-
"::mlir::ArrayRef<int>":$reductionDims),
449+
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
449450
/*methodBody=*/"",
450451
/*defaultImplementation=*/[{
451452
return failure();

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2775,10 +2775,11 @@ void transform::TileReductionUsingForOp::build(
27752775
// TODO: support mixed static-dynamic (see TileUsingForallOp).
27762776
MLIRContext *ctx = builder.getContext();
27772777
auto opTy = transform::AnyOpType::get(ctx);
2778-
auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2778+
auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
27792779
build(builder, result,
27802780
/*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
27812781
/*target=*/target,
2782+
/*reduction_dims=*/nullptr,
27822783
/*tile_sizes=*/staticTileSizesAttr);
27832784
}
27842785

@@ -2794,12 +2795,30 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
27942795
target->getLoc(),
27952796
"Operation should implement PartialReductionOpInterface");
27962797
}
2797-
FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
2798-
rewriter, partialReductionOp,
2799-
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
28002798

2801-
if (failed(result))
2802-
return emitDefaultSilenceableFailure(target);
2799+
SmallVector<unsigned> reductionDims =
2800+
extractFromIntegerArrayAttr<unsigned>(getReductionDims());
2801+
if (reductionDims.empty()) {
2802+
for (auto [idx, iteratorType] :
2803+
llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
2804+
if (iteratorType == utils::IteratorType::reduction)
2805+
reductionDims.push_back(idx);
2806+
}
2807+
}
2808+
2809+
scf::SCFTilingOptions options;
2810+
options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
2811+
options.setReductionTilingStrategy(
2812+
ReductionTilingStrategy::PartialReductionOuterReduction);
2813+
options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
2814+
options.setReductionDims(reductionDims);
2815+
FailureOr<scf::SCFTilingResult> result =
2816+
scf::tileUsingSCF(rewriter, partialReductionOp, options);
2817+
2818+
if (failed(result)) {
2819+
return emitSilenceableFailure(getLoc(),
2820+
"failed to tile using partial reduction");
2821+
}
28032822
rewriter.replaceOp(target, result->replacements);
28042823
for (Value initValue : result->initialValues)
28052824
results.push_back(initValue.getDefiningOp());

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
109109
}
110110

111111
FailureOr<StaticContinuousTileSizeSpecification>
112-
mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
113-
unsigned dimension,
112+
mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
114113
unsigned targetSize) {
115114

116115
assert(!op.hasDynamicShape() &&
@@ -183,8 +182,8 @@ mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
183182

184183
// Find the trip count of the iteration space dimension for which the tile
185184
// sizes are computed.
186-
Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
187-
loopRanges[dimension].size);
185+
Value loopRange =
186+
getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
188187
ContinuousTileSizeSpecification spec;
189188

190189
// Compute the tile sizes and the respective numbers of tiles.
@@ -633,16 +632,18 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
633632
if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
634633
return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
635634
"many elements as number of threads");
636-
int reductionDim = static_cast<int>(redDims.front());
637635

638636
if (redDims.front() >= numThreads.size())
639637
return b.notifyMatchFailure(
640638
op, "reduction dimension must be mapped to threads");
641639

642640
// 1. Create the inital tensor value.
641+
unsigned reductionDim = redDims.front();
642+
SetVector<unsigned> reductionDims;
643+
reductionDims.insert(reductionDim);
643644
FailureOr<SmallVector<Value>> maybeInitTensors =
644645
op.generateInitialTensorForPartialReduction(b, loc, numThreads,
645-
reductionDim);
646+
reductionDims);
646647
if (failed(maybeInitTensors))
647648
return b.notifyMatchFailure(
648649
op, "Failed to create inital tensors for partial reduction");
@@ -780,7 +781,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
780781
// 7. Merge the partial reductions.
781782
b.setInsertionPointAfter(forallOp);
782783
FailureOr<MergeResult> mergeResult =
783-
op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
784+
op.mergeReductions(b, loc, forallOp->getResults(), reductionDims);
784785
if (failed(mergeResult)) {
785786
return failure();
786787
}

0 commit comments

Comments
 (0)