Skip to content

Commit 7181785

Browse files
[mlir][PartialReductionTilingInterface] Generalize implementation of tileUsingSCF for ReductionTilingStrategy::PartialOuterReduction. (#143467)
This is a precursor to generalizing the `tileUsingSCF` to handle `ReductionTilingStrategy::PartialOuterParallel` strategy. This change itself is generalizing/refactoring the current implementation that supports only `ReductionTilingStrategy::PartialOuterReduction`. Changes in this PR - Move the `ReductionTilingStrategy` enum out of `scf::SCFTilingOptions` and make them visible to `TilingInterface`. - `PartialTilingInterface` changes - Pass the `tilingStrategy` used for partial reduction to `tileToPartialReduction`. - Pass the reduction dimension along as `const llvm::SetVector<unsigned> &`. - Allow `scf::SCFTilingOptions` to set the reduction dimensions that are to be tiled. - Change `structured.tiled_reduction_using_for` to allow specification of the reduction dimensions to be partially tiled. Signed-off-by: MaheshRavishankar <[email protected]>
1 parent e80acd4 commit 7181785

File tree

9 files changed

+433
-251
lines changed

9 files changed

+433
-251
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1859,6 +1859,10 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
18591859
- the result-combining op,
18601860
- the parent `for` op.
18611861

1862+
The `reduction_dims` can be used to specify the subset of reduction dimensions
1863+
of the operation to tile. If left unspecified, all reduction dimensions are
1864+
tiled.
1865+
18621866
#### Example:
18631867

18641868
```
@@ -1909,7 +1913,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
19091913

19101914
// TODO: support mixed static-dynamic (see TileUsingForallOp).
19111915
let arguments = (ins TransformHandleTypeInterface:$target,
1912-
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
1916+
DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
1917+
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
19131918
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
19141919
TransformHandleTypeInterface:$split_op,
19151920
TransformHandleTypeInterface:$combining_op,
@@ -1922,6 +1927,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
19221927

19231928
let assemblyFormat = [{
19241929
$target
1930+
(`reduction_dims` `=` $reduction_dims^)?
19251931
`by` `tile_sizes` `=` $tile_sizes
19261932
attr-dict
19271933
`:` functional-type(operands, results)

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -85,28 +85,21 @@ struct SCFTilingOptions {
8585
return *this;
8686
}
8787

88+
/// Specify mapping of loops to devices. This is only respected when the loop
89+
/// constructs support such a mapping (like `scf.forall`). Will be ignored
90+
/// when using loop constructs that dont support such a mapping (like
91+
/// `scf.for`)
92+
SmallVector<Attribute> mappingVector = {};
93+
SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
94+
mappingVector = llvm::to_vector(mapping);
95+
return *this;
96+
}
97+
98+
//-------------------------------------------------------------------------//
99+
// Options related reduction tiling
100+
//-------------------------------------------------------------------------//
101+
88102
/// Specify how reduction dimensions should be tiled.
89-
///
90-
/// Tiling can be thought of as splitting a dimension into 2 and materializing
91-
/// the outer dimension as a loop:
92-
///
93-
/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
94-
///
95-
/// For parallel dimensions, the split can only happen in one way, with both
96-
/// dimensions being parallel. For reduction dimensions however, there is a
97-
/// choice in how we split the reduction dimension. This enum exposes this
98-
/// choice.
99-
enum class ReductionTilingStrategy {
100-
// [reduction] -> [reduction1, reduction2]
101-
// -> loop[reduction1] { [reduction2] }
102-
FullReduction,
103-
// [reduction] -> [reduction1, parallel2]
104-
// -> loop[reduction1] { [parallel2] }; merge[reduction1]
105-
PartialReductionOuterReduction,
106-
// [reduction] -> [parallel1, reduction2]
107-
// -> loop[parallel1] { [reduction2] }; merge[parallel1]
108-
PartialReductionOuterParallel
109-
};
110103
ReductionTilingStrategy reductionStrategy =
111104
ReductionTilingStrategy::FullReduction;
112105
SCFTilingOptions &
@@ -115,13 +108,13 @@ struct SCFTilingOptions {
115108
return *this;
116109
}
117110

118-
/// Specify mapping of loops to devices. This is only respected when the loop
119-
/// constructs support such a mapping (like `scf.forall`). Will be ignored
120-
/// when using loop constructs that dont support such a mapping (like
121-
/// `scf.for`)
122-
SmallVector<Attribute> mappingVector = {};
123-
SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
124-
mappingVector = llvm::to_vector(mapping);
111+
/// Specify the reduction dimensions to be tiled. Note that this needs to be
112+
/// specified. If left unspecified, then none of the reduction dimensions are
113+
/// tiled.
114+
SetVector<unsigned> reductionDims;
115+
SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
116+
reductionDims.clear();
117+
reductionDims.insert(dims.begin(), dims.end());
125118
return *this;
126119
}
127120
};

mlir/include/mlir/Interfaces/TilingInterface.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,27 @@ struct TilingResult {
3636
SmallVector<Operation *> generatedSlices;
3737
};
3838

39+
/// Tiling can be thought of as splitting a dimension into 2 and
40+
/// materializing the outer dimension as a loop:
41+
///
42+
/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
43+
///
44+
/// For parallel dimensions, the split can only happen in one way, with both
45+
/// dimensions being parallel. For reduction dimensions however, there is a
46+
/// choice in how we split the reduction dimension. This enum exposes this
47+
/// choice.
48+
enum class ReductionTilingStrategy {
49+
// [reduction] -> [reduction1, reduction2]
50+
// -> loop[reduction1] { [reduction2] }
51+
FullReduction,
52+
// [reduction] -> [reduction1, parallel2]
53+
// -> loop[reduction1] { [parallel2] }; merge[reduction1]
54+
PartialReductionOuterReduction,
55+
// [reduction] -> [parallel1, reduction2]
56+
// -> loop[parallel1] { [reduction2] }; merge[parallel1]
57+
PartialReductionOuterParallel
58+
};
59+
3960
/// Container for the result of merge operation of tiling.
4061
/// - `mergeOps` contains operations created during the merge.
4162
/// - `replacements` contains the values that represents the result of the

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def PartialReductionOpInterface :
384384
"::mlir::OpBuilder &":$b,
385385
"Location":$loc,
386386
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
387-
"::mlir::ArrayRef<int>":$reductionDim),
387+
"const ::mlir::SetVector<unsigned> &":$reductionDims),
388388
/*methodBody=*/"",
389389
/*defaultImplementation=*/[{
390390
return failure();
@@ -402,10 +402,11 @@ def PartialReductionOpInterface :
402402
/*args=*/(ins
403403
"::mlir::OpBuilder &":$b,
404404
"Location ":$loc,
405+
"::mlir::ReductionTilingStrategy":$tilingStrategy,
405406
"ValueRange":$init,
406407
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
407408
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
408-
"::mlir::ArrayRef<int>":$reductionDims),
409+
"const ::llvm::SetVector<unsigned> &":$reductionDims),
409410
/*methodBody=*/"",
410411
/*defaultImplementation=*/[{
411412
return failure();
@@ -423,7 +424,7 @@ def PartialReductionOpInterface :
423424
"::mlir::OpBuilder &":$b,
424425
"Location ":$loc,
425426
"ValueRange":$partialReduce,
426-
"::mlir::ArrayRef<int>":$reductionDim),
427+
"const ::mlir::SetVector<unsigned> &":$reductionDims),
427428
/*methodBody=*/"",
428429
/*defaultImplementation=*/[{
429430
return failure();
@@ -443,9 +444,9 @@ def PartialReductionOpInterface :
443444
"unsigned":$resultNumber,
444445
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
445446
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
447+
"const ::mlir::SetVector<unsigned> &":$reductionDims,
446448
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
447-
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
448-
"::mlir::ArrayRef<int>":$reductionDims),
449+
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
449450
/*methodBody=*/"",
450451
/*defaultImplementation=*/[{
451452
return failure();

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2947,10 +2947,11 @@ void transform::TileReductionUsingForOp::build(
29472947
// TODO: support mixed static-dynamic (see TileUsingForallOp).
29482948
MLIRContext *ctx = builder.getContext();
29492949
auto opTy = transform::AnyOpType::get(ctx);
2950-
auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
2950+
auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
29512951
build(builder, result,
29522952
/*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
29532953
/*target=*/target,
2954+
/*reduction_dims=*/nullptr,
29542955
/*tile_sizes=*/staticTileSizesAttr);
29552956
}
29562957

@@ -2966,12 +2967,30 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
29662967
target->getLoc(),
29672968
"Operation should implement PartialReductionOpInterface");
29682969
}
2969-
FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
2970-
rewriter, partialReductionOp,
2971-
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
29722970

2973-
if (failed(result))
2974-
return emitDefaultSilenceableFailure(target);
2971+
SmallVector<unsigned> reductionDims =
2972+
extractFromIntegerArrayAttr<unsigned>(getReductionDims());
2973+
if (reductionDims.empty()) {
2974+
for (auto [idx, iteratorType] :
2975+
llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
2976+
if (iteratorType == utils::IteratorType::reduction)
2977+
reductionDims.push_back(idx);
2978+
}
2979+
}
2980+
2981+
scf::SCFTilingOptions options;
2982+
options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
2983+
options.setReductionTilingStrategy(
2984+
ReductionTilingStrategy::PartialReductionOuterReduction);
2985+
options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
2986+
options.setReductionDims(reductionDims);
2987+
FailureOr<scf::SCFTilingResult> result =
2988+
scf::tileUsingSCF(rewriter, partialReductionOp, options);
2989+
2990+
if (failed(result)) {
2991+
return emitSilenceableFailure(getLoc(),
2992+
"failed to tile using partial reduction");
2993+
}
29752994
rewriter.replaceOp(target, result->replacements);
29762995
for (Value initValue : result->initialValues)
29772996
results.push_back(initValue.getDefiningOp());

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
109109
}
110110

111111
FailureOr<StaticContinuousTileSizeSpecification>
112-
mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
113-
unsigned dimension,
112+
mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
114113
unsigned targetSize) {
115114

116115
assert(!op.hasDynamicShape() &&
@@ -183,8 +182,8 @@ mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,
183182

184183
// Find the trip count of the iteration space dimension for which the tile
185184
// sizes are computed.
186-
Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
187-
loopRanges[dimension].size);
185+
Value loopRange =
186+
getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
188187
ContinuousTileSizeSpecification spec;
189188

190189
// Compute the tile sizes and the respective numbers of tiles.
@@ -633,16 +632,18 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
633632
if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
634633
return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
635634
"many elements as number of threads");
636-
int reductionDim = static_cast<int>(redDims.front());
637635

638636
if (redDims.front() >= numThreads.size())
639637
return b.notifyMatchFailure(
640638
op, "reduction dimension must be mapped to threads");
641639

642640
// 1. Create the inital tensor value.
641+
unsigned reductionDim = redDims.front();
642+
SetVector<unsigned> reductionDims;
643+
reductionDims.insert(reductionDim);
643644
FailureOr<SmallVector<Value>> maybeInitTensors =
644645
op.generateInitialTensorForPartialReduction(b, loc, numThreads,
645-
reductionDim);
646+
reductionDims);
646647
if (failed(maybeInitTensors))
647648
return b.notifyMatchFailure(
648649
op, "Failed to create inital tensors for partial reduction");
@@ -780,7 +781,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
780781
// 7. Merge the partial reductions.
781782
b.setInsertionPointAfter(forallOp);
782783
FailureOr<MergeResult> mergeResult =
783-
op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
784+
op.mergeReductions(b, loc, forallOp->getResults(), reductionDims);
784785
if (failed(mergeResult)) {
785786
return failure();
786787
}

0 commit comments

Comments
 (0)