Skip to content

Commit 9329b20

Browse files
authored
[mlir][TilingInterface] Allow multiple results in PartialReductionOpInterface (#92624)
This patch adds support for reducing operations with multiple results using PartialReductionOpInterface. Also adds an implementation of PartialReductionOpInterface for multiple results for linalg.generic.
1 parent 0370beb commit 9329b20

File tree

10 files changed

+236
-149
lines changed

10 files changed

+236
-149
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,7 +1681,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
16811681
// TODO: support mixed static-dynamic (see TileUsingForallOp).
16821682
let arguments = (ins TransformHandleTypeInterface:$target,
16831683
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
1684-
let results = (outs TransformHandleTypeInterface:$fill_op,
1684+
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
16851685
TransformHandleTypeInterface:$split_linalg_op,
16861686
TransformHandleTypeInterface:$combining_linalg_op,
16871687
TransformHandleTypeInterface:$for_op);
@@ -1787,7 +1787,7 @@ def TileReductionUsingForallOp :
17871787
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
17881788
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
17891789
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
1790-
let results = (outs TransformHandleTypeInterface:$fill_op,
1790+
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
17911791
TransformHandleTypeInterface:$split_linalg_op,
17921792
TransformHandleTypeInterface:$combining_linalg_op,
17931793
TransformHandleTypeInterface:$forall_op);

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -876,8 +876,8 @@ struct ForallReductionTilingResult {
876876
Operation *parallelTiledOp;
877877
/// The final reduction operation merging all the partial reductions.
878878
Operation *mergeOp;
879-
/// The op initializing the tensor used for partial reductions.
880-
Operation *initialOp;
879+
/// Initial values used for partial reductions.
880+
SmallVector<Value> initialValues;
881881
/// The `scf.forall` operation that iterate over the tiles.
882882
scf::ForallOp loops;
883883
};

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ struct SCFReductionTilingResult {
250250
Operation *parallelTiledOp;
251251
/// The final reduction operation merging all the partial reductions.
252252
Operation *mergeOp;
253-
/// Initial op
254-
Operation *initialOp;
253+
/// Initial values used for reduction.
254+
SmallVector<Value> initialValues;
255255
/// The loop operations that iterate over the tiles.
256256
SmallVector<LoopLikeOpInterface> loops;
257257
};

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,11 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
170170
operation reduction. The tensor shape is equal to operation result
171171
shape with new dimension for each non zero tile size.
172172
}],
173-
/*retType=*/"FailureOr<Operation*>",
173+
/*retType=*/"FailureOr<SmallVector<Value>>",
174174
/*methodName=*/"generateInitialTensorForPartialReduction",
175175
/*args=*/(ins
176176
"OpBuilder &":$b,
177-
"Location ":$loc,
177+
"Location":$loc,
178178
"ArrayRef<OpFoldResult>":$sizes,
179179
"ArrayRef<int>":$reductionDim),
180180
/*methodBody=*/"",

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2523,7 +2523,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
25232523

25242524
if (failed(result))
25252525
return emitDefaultSilenceableFailure(target);
2526-
results.push_back(result->initialOp);
2526+
for (Value initValue : result->initialValues)
2527+
results.push_back(initValue.getDefiningOp());
25272528
results.push_back(result->parallelTiledOp);
25282529
results.push_back(result->mergeOp);
25292530
results.push_back(result->loops.front());
@@ -2574,7 +2575,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
25742575
diag.attachNote(target.getLoc()) << "target operation";
25752576
return diag;
25762577
}
2577-
results.push_back(result->initialOp);
2578+
for (Value initValue : result->initialValues)
2579+
results.push_back(initValue.getDefiningOp());
25782580
results.push_back(result->parallelTiledOp);
25792581
results.push_back(result->mergeOp);
25802582
results.push_back(result->loops);

mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,12 @@ static Value createDestinationPassingStyleInitOperand(
155155
tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
156156
PartialReductionOpInterface partialReductionIface =
157157
llvm::cast<PartialReductionOpInterface>(op.getOperation());
158-
FailureOr<Operation *> reductionNeutralTensorOp =
158+
assert(op->getNumResults() == 1 && "Multiple results not supported.");
159+
FailureOr<SmallVector<Value>> reductionNeutralTensor =
159160
partialReductionIface.generateInitialTensorForPartialReduction(
160161
builder, builder.getLoc(), shape, {});
161-
assert(succeeded(reductionNeutralTensorOp));
162-
builder.create<scf::YieldOp>(
163-
reductionNeutralTensorOp.value()->getResult(0));
162+
assert(succeeded(reductionNeutralTensor));
163+
builder.create<scf::YieldOp>(reductionNeutralTensor.value());
164164
}
165165
return ifOp.getResult(0);
166166
}
@@ -173,8 +173,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
173173
ImplicitLocOpBuilder &builder) {
174174
// TODO: add support for multiple destination passing style initial value
175175
// operands.
176-
// PartialReductionOpInterface::generateInitialTensorForPartialReduction
177-
// needs to also support multiple DPS initial operands.
176+
assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
178177
SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
179178
auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
180179
Value spmdizedInitOperand =

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -692,12 +692,13 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
692692
op, "reduction dimension must be mapped to threads");
693693

694694
// 1. Create the inital tensor value.
695-
FailureOr<Operation *> identityTensor =
695+
FailureOr<SmallVector<Value>> maybeInitTensors =
696696
op.generateInitialTensorForPartialReduction(b, loc, numThreads,
697697
reductionDim);
698-
if (failed(identityTensor))
699-
return b.notifyMatchFailure(op,
700-
"cannot create a tensor of identity value.");
698+
if (failed(maybeInitTensors))
699+
return b.notifyMatchFailure(
700+
op, "Failed to create inital tensors for partial reduction");
701+
SmallVector<Value> &initTensors = maybeInitTensors.value();
701702

702703
// Gather destination tensors.
703704
SmallVector<Value> dest;
@@ -715,8 +716,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
715716

716717
// 2. Create the ForallOp with an empty region.
717718
scf::ForallOp forallOp = b.create<scf::ForallOp>(
718-
loc, getAsOpFoldResult(materializedNonZeroNumThreads),
719-
(*identityTensor)->getResults(), mapping);
719+
loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
720+
mapping);
720721

721722
// 3. Calculate the tile offsets and sizes for the subsequent loop that will
722723
// be nested under `forallOp`.
@@ -726,7 +727,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
726727
/*nominalTileSizes=*/std::nullopt, tiledOffsets,
727728
tiledSizes);
728729

729-
// 4. Clone the tileable op and update its destination operands to use the
730+
// 4b. Clone the tileable op and update its destination operands to use the
730731
// output bbArgs of the ForallOp.
731732
SmallVector<Value> tilingResults;
732733
ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
@@ -838,7 +839,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
838839

839840
// 8. Return.
840841
ForallReductionTilingResult results;
841-
results.initialOp = *identityTensor;
842+
results.initialValues = initTensors;
842843
results.loops = forallOp;
843844
results.parallelTiledOp = tiledOp;
844845
results.mergeOp = mergeOp;

0 commit comments

Comments
 (0)