Skip to content

[mlir][PartialReductionTilingInterface] Generalize implementation of tileUsingSCF for ReductionTilingStrategy::PartialOuterReduction. #143467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1850,6 +1850,10 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
- the result-combining op,
- the parent `for` op.

The `reduction_dims` can be used to specify the subset of reduction dimensions
of the operation to tile. If left unspecified, all reduction dimensions are
tiled.

#### Example:

```
Expand Down Expand Up @@ -1900,7 +1904,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u

// TODO: support mixed static-dynamic (see TileUsingForallOp).
let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
TransformHandleTypeInterface:$split_op,
TransformHandleTypeInterface:$combining_op,
Expand All @@ -1913,6 +1918,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u

let assemblyFormat = [{
$target
(`reduction_dims` `=` $reduction_dims^)?
`by` `tile_sizes` `=` $tile_sizes
attr-dict
`:` functional-type(operands, results)
Expand Down
49 changes: 21 additions & 28 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,28 +85,21 @@ struct SCFTilingOptions {
return *this;
}

/// Specify mapping of loops to devices. This is only respected when the loop
/// constructs support such a mapping (like `scf.forall`). Will be ignored
/// when using loop constructs that dont support such a mapping (like
/// `scf.for`)
SmallVector<Attribute> mappingVector = {};
SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
mappingVector = llvm::to_vector(mapping);
return *this;
}

//-------------------------------------------------------------------------//
// Options related reduction tiling
//-------------------------------------------------------------------------//

/// Specify how reduction dimensions should be tiled.
///
/// Tiling can be thought of as splitting a dimension into 2 and materializing
/// the outer dimension as a loop:
///
/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
///
/// For parallel dimensions, the split can only happen in one way, with both
/// dimensions being parallel. For reduction dimensions however, there is a
/// choice in how we split the reduction dimension. This enum exposes this
/// choice.
enum class ReductionTilingStrategy {
// [reduction] -> [reduction1, reduction2]
// -> loop[reduction1] { [reduction2] }
FullReduction,
// [reduction] -> [reduction1, parallel2]
// -> loop[reduction1] { [parallel2] }; merge[reduction1]
PartialReductionOuterReduction,
// [reduction] -> [parallel1, reduction2]
// -> loop[parallel1] { [reduction2] }; merge[parallel1]
PartialReductionOuterParallel
};
ReductionTilingStrategy reductionStrategy =
ReductionTilingStrategy::FullReduction;
SCFTilingOptions &
Expand All @@ -115,13 +108,13 @@ struct SCFTilingOptions {
return *this;
}

/// Specify mapping of loops to devices. This is only respected when the loop
/// constructs support such a mapping (like `scf.forall`). Will be ignored
/// when using loop constructs that dont support such a mapping (like
/// `scf.for`)
SmallVector<Attribute> mappingVector = {};
SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
mappingVector = llvm::to_vector(mapping);
/// Specify the reduction dimensions to be tiled. Note that this needs to be
/// specified. If left unspecified, then none of the reduction dimensions are
/// tiled.
SetVector<unsigned> reductionDims;
SCFTilingOptions &setReductionDims(ArrayRef<unsigned> dims) {
reductionDims.clear();
reductionDims.insert(dims.begin(), dims.end());
return *this;
}
};
Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/Interfaces/TilingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,27 @@ struct TilingResult {
SmallVector<Operation *> generatedSlices;
};

/// Tiling can be thought of as splitting a dimension into 2 and
/// materializing the outer dimension as a loop:
///
/// op[original] -> op[original / x, x] -> loop[original] { op[x] }
///
/// For parallel dimensions, the split can only happen in one way, with both
/// dimensions being parallel. For reduction dimensions however, there is a
/// choice in how we split the reduction dimension. This enum exposes this
/// choice.
enum class ReductionTilingStrategy {
// [reduction] -> [reduction1, reduction2]
// -> loop[reduction1] { [reduction2] }
FullReduction,
// [reduction] -> [reduction1, parallel2]
// -> loop[reduction1] { [parallel2] }; merge[reduction1]
PartialReductionOuterReduction,
// [reduction] -> [parallel1, reduction2]
// -> loop[parallel1] { [reduction2] }; merge[parallel1]
PartialReductionOuterParallel
};

/// Container for the result of merge operation of tiling.
/// - `mergeOps` contains operations created during the merge.
/// - `replacements` contains the values that represents the result of the
Expand Down
11 changes: 6 additions & 5 deletions mlir/include/mlir/Interfaces/TilingInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def PartialReductionOpInterface :
"::mlir::OpBuilder &":$b,
"Location":$loc,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
"::mlir::ArrayRef<int>":$reductionDim),
"const ::mlir::SetVector<unsigned> &":$reductionDims),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
Expand All @@ -402,10 +402,11 @@ def PartialReductionOpInterface :
/*args=*/(ins
"::mlir::OpBuilder &":$b,
"Location ":$loc,
"::mlir::ReductionTilingStrategy":$tilingStrategy,
"ValueRange":$init,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
"::mlir::ArrayRef<int>":$reductionDims),
"const ::llvm::SetVector<unsigned> &":$reductionDims),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
Expand All @@ -423,7 +424,7 @@ def PartialReductionOpInterface :
"::mlir::OpBuilder &":$b,
"Location ":$loc,
"ValueRange":$partialReduce,
"::mlir::ArrayRef<int>":$reductionDim),
"const ::mlir::SetVector<unsigned> &":$reductionDims),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
Expand All @@ -443,9 +444,9 @@ def PartialReductionOpInterface :
"unsigned":$resultNumber,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
"const ::mlir::SetVector<unsigned> &":$reductionDims,
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
"::mlir::ArrayRef<int>":$reductionDims),
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
Expand Down
31 changes: 25 additions & 6 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2966,10 +2966,11 @@ void transform::TileReductionUsingForOp::build(
// TODO: support mixed static-dynamic (see TileUsingForallOp).
MLIRContext *ctx = builder.getContext();
auto opTy = transform::AnyOpType::get(ctx);
auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
auto staticTileSizesAttr = builder.getI64ArrayAttr(staticTileSizes);
build(builder, result,
/*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
/*target=*/target,
/*reduction_dims=*/nullptr,
/*tile_sizes=*/staticTileSizesAttr);
}

Expand All @@ -2985,12 +2986,30 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
target->getLoc(),
"Operation should implement PartialReductionOpInterface");
}
FailureOr<scf::SCFTilingResult> result = scf::tileReductionUsingScf(
rewriter, partialReductionOp,
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));

if (failed(result))
return emitDefaultSilenceableFailure(target);
SmallVector<unsigned> reductionDims =
extractFromIntegerArrayAttr<unsigned>(getReductionDims());
if (reductionDims.empty()) {
for (auto [idx, iteratorType] :
llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
if (iteratorType == utils::IteratorType::reduction)
reductionDims.push_back(idx);
}
}

scf::SCFTilingOptions options;
options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
options.setReductionTilingStrategy(
ReductionTilingStrategy::PartialReductionOuterReduction);
options.setTileSizes(getAsOpFoldResult(getTileSizesAttr()));
options.setReductionDims(reductionDims);
FailureOr<scf::SCFTilingResult> result =
scf::tileUsingSCF(rewriter, partialReductionOp, options);

if (failed(result)) {
return emitSilenceableFailure(getLoc(),
"failed to tile using partial reduction");
}
rewriter.replaceOp(target, result->replacements);
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
Expand Down
15 changes: 8 additions & 7 deletions mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b,
}

FailureOr<StaticContinuousTileSizeSpecification>
mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op,
unsigned dimension,
mlir::linalg::computeStaticContinuousTileSizes(LinalgOp op, unsigned dimension,
unsigned targetSize) {

assert(!op.hasDynamicShape() &&
Expand Down Expand Up @@ -183,8 +182,8 @@ mlir::linalg::computeContinuousTileSizes(OpBuilder &builder, TilingInterface op,

// Find the trip count of the iteration space dimension for which the tile
// sizes are computed.
Value loopRange = getValueOrCreateConstantIndexOp(b, loc,
loopRanges[dimension].size);
Value loopRange =
getValueOrCreateConstantIndexOp(b, loc, loopRanges[dimension].size);
ContinuousTileSizeSpecification spec;

// Compute the tile sizes and the respective numbers of tiles.
Expand Down Expand Up @@ -633,16 +632,18 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
"many elements as number of threads");
int reductionDim = static_cast<int>(redDims.front());

if (redDims.front() >= numThreads.size())
return b.notifyMatchFailure(
op, "reduction dimension must be mapped to threads");

// 1. Create the inital tensor value.
unsigned reductionDim = redDims.front();
SetVector<unsigned> reductionDims;
reductionDims.insert(reductionDim);
FailureOr<SmallVector<Value>> maybeInitTensors =
op.generateInitialTensorForPartialReduction(b, loc, numThreads,
reductionDim);
reductionDims);
if (failed(maybeInitTensors))
return b.notifyMatchFailure(
op, "Failed to create inital tensors for partial reduction");
Expand Down Expand Up @@ -780,7 +781,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
// 7. Merge the partial reductions.
b.setInsertionPointAfter(forallOp);
FailureOr<MergeResult> mergeResult =
op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
op.mergeReductions(b, loc, forallOp->getResults(), reductionDims);
if (failed(mergeResult)) {
return failure();
}
Expand Down
Loading
Loading