Skip to content

[mlir][PartialReductionTilingInterface] Add support for ReductionTilingStrategy::PartialReductionOuterParallel in tileUsingSCF. #143988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2019,6 +2019,7 @@ def TileReductionUsingForallOp :

// TODO: support mixed static-dynamic (see TileUsingForallOp).
let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<I64ArrayAttr, "{}">:$reduction_dims,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
Expand All @@ -2036,10 +2037,11 @@ def TileReductionUsingForallOp :

let assemblyFormat = [{
$target
(`reduction_dims` `=` $reduction_dims^)?
`by`
(`num_threads` `=` $num_threads^)?
(`,` `tile_sizes` `=` $tile_sizes^)?
(`,` `mapping` `=` $mapping^)?
(`tile_sizes` `=` $tile_sizes^)?
(`mapping` `=` $mapping^)?
attr-dict
`:` functional-type(operands, results)
}];
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ SmallVector<OpFoldResult> getMixedValues(ArrayRef<int64_t> staticValues,
/// corresponding pair of arrays. This is the inverse function of
/// `getMixedValues`.
std::pair<SmallVector<int64_t>, SmallVector<Value>>
decomposeMixedValues(const SmallVectorImpl<OpFoldResult> &mixedValues);
decomposeMixedValues(ArrayRef<OpFoldResult> mixedValues);

/// Helper to sort `values` according to matching `keys`.
SmallVector<Value>
Expand Down
28 changes: 23 additions & 5 deletions mlir/include/mlir/Interfaces/TilingInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -367,23 +367,28 @@ def PartialReductionOpInterface :
OpInterface<"PartialReductionOpInterface", [TilingInterface]> {
let description = [{
Interface for allowing operations to expose information needed to
tile reductions using partial reduction followed by merge. This is
complementary to TilingInterface to tile reductions.
tile reductions using partial reduction followed by merge. This
extends the `TilingInterface` to allow splitting a reduction
dimension into a parallel dimension and reduction dimension.
The materialized inter-tile loop could either be the reduction dimension
(i.e. `ReductionTilingStrategy::PartialReductionOuterReduction`) or
the parallel dimension (i.e.
`ReductionTilingStrategy::PartialReductionOuterParallel`).
}];
let cppNamespace = "::mlir";
let methods = [
InterfaceMethod<
/*desc=*/[{
Method to generate a tensor initialized with the identity value of the
operation reduction. The tensor shape is equal to operation result
reduction operator. The tensor shape is equal to operation result
shape with new dimension for each non zero tile size.
}],
/*retType=*/"::mlir::FailureOr<SmallVector<Value>>",
/*methodName=*/"generateInitialTensorForPartialReduction",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
"Location":$loc,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$tileSizes,
"const ::mlir::SetVector<unsigned> &":$reductionDims),
/*methodBody=*/"",
/*defaultImplementation=*/[{
Expand All @@ -396,6 +401,11 @@ def PartialReductionOpInterface :
reduction dimension are converted to parallel dimensions with a size
less or equal to the tile size. This is meant to be used with
`mergeReductions` method which will combine the partial reductions.
The method receives the `offsets` and `sizes` for all iteration space
dimensions, as well as the iteration number of the tiled reduction
dimensions (which is the induction variable of the inter-tile loop
for the reduction dimension divided by the step of the loop) in
`splitReductionIvs`.
}],
/*retType=*/"::mlir::FailureOr<TilingResult>",
/*methodName=*/"tileToPartialReduction",
Expand All @@ -406,7 +416,8 @@ def PartialReductionOpInterface :
"ValueRange":$init,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes,
"const ::llvm::SetVector<unsigned> &":$reductionDims),
"const ::llvm::SetVector<unsigned> &":$reductionDims,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$splitReductionIvs),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
Expand Down Expand Up @@ -436,15 +447,22 @@ def PartialReductionOpInterface :
the tiled operation. This is same as
TilingInterface::getResultTilePosition, but determines the result
tile position for partial reduction.
The method receives the `offsets` and `sizes` for all iteration space
dimensions, as well as the iteration number of the tiled reduction
dimensions (which is the induction variable of the inter-tile loop
for the reduction dimension divided by the tile size specified) in
`splitReductionIvs`.
}],
/*retType=*/"::llvm::LogicalResult",
/*methodName=*/"getPartialResultTilePosition",
/*args=*/(ins
"::mlir::OpBuilder &":$b,
"unsigned":$resultNumber,
"ReductionTilingStrategy":$tilingStrategy,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
"const ::mlir::SetVector<unsigned> &":$reductionDims,
"::mlir::ArrayRef<::mlir::OpFoldResult>":$splitReductionIvs,
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes),
/*methodBody=*/"",
Expand Down
37 changes: 30 additions & 7 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3022,6 +3022,7 @@ void transform::TileReductionUsingForallOp::build(
build(builder, result,
/*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
/*target=*/target,
/*reduction_dims=*/{},
/*num_threads=*/staticNumThreadsAttr,
/*tile_sizes=*/staticTileSizesAttr,
/*mapping=*/mapping);
Expand All @@ -3036,23 +3037,45 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
SmallVector<OpFoldResult> tileSizes =
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
FailureOr<linalg::ForallReductionTilingResult> result =
linalg::tileReductionUsingForall(
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
numThreads, tileSizes, getMapping());

scf::SCFTilingOptions options;
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
options.setReductionTilingStrategy(
ReductionTilingStrategy::PartialReductionOuterParallel);
if (!getNumThreads().empty()) {
options.setNumThreads(numThreads);
} else {
options.setTileSizes(tileSizes);
}
if (auto mapping = getMapping()) {
options.setMapping(mapping.value().getValue());
}
SmallVector<unsigned> reductionDims =
extractFromIntegerArrayAttr<unsigned>(getReductionDims());
if (reductionDims.empty()) {
for (auto [idx, iteratorType] :
llvm::enumerate(target.getIteratorTypesArray())) {
if (iteratorType == utils::IteratorType::reduction)
reductionDims.push_back(idx);
}
}
options.setReductionDims(reductionDims);
FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
rewriter, cast<TilingInterface>(target.getOperation()), options);

if (failed(result)) {
auto diag = emitSilenceableError() << "could not tile reduction";
diag.attachNote(target.getLoc()) << "target operation";
return diag;
}
rewriter.replaceOp(target, result->replacements);

for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
for (auto parallelTiledOp : result->parallelTiledOps)
for (auto parallelTiledOp : result->tiledOps)
results.push_back(parallelTiledOp);
for (auto mergeOp : result->mergeOps)
results.push_back(mergeOp);
results.push_back(result->loops);
results.push_back(result->loops.front());
return DiagnosedSilenceableFailure::success();
}

Expand Down
Loading