Skip to content

Commit 17dca46

Browse files
committed
[mlir][TilingInterface] Allow multiple results in PartialReductionOpInterface
1 parent 1e7d047 commit 17dca46

File tree

10 files changed

+199
-126
lines changed

10 files changed

+199
-126
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1681,7 +1681,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
16811681
// TODO: support mixed static-dynamic (see TileUsingForallOp).
16821682
let arguments = (ins TransformHandleTypeInterface:$target,
16831683
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
1684-
let results = (outs TransformHandleTypeInterface:$fill_op,
1684+
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
16851685
TransformHandleTypeInterface:$split_linalg_op,
16861686
TransformHandleTypeInterface:$combining_linalg_op,
16871687
TransformHandleTypeInterface:$for_op);
@@ -1787,7 +1787,7 @@ def TileReductionUsingForallOp :
17871787
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
17881788
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
17891789
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
1790-
let results = (outs TransformHandleTypeInterface:$fill_op,
1790+
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
17911791
TransformHandleTypeInterface:$split_linalg_op,
17921792
TransformHandleTypeInterface:$combining_linalg_op,
17931793
TransformHandleTypeInterface:$forall_op);

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -876,8 +876,8 @@ struct ForallReductionTilingResult {
876876
Operation *parallelTiledOp;
877877
/// The final reduction operation merging all the partial reductions.
878878
Operation *mergeOp;
879-
/// The op initializing the tensor used for partial reductions.
880-
Operation *initialOp;
879+
/// Initial values used for partial reductions.
880+
SmallVector<Value> initialValues;
881881
/// The `scf.forall` operation that iterate over the tiles.
882882
scf::ForallOp loops;
883883
};

mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ struct SCFReductionTilingResult {
250250
Operation *parallelTiledOp;
251251
/// The final reduction operation merging all the partial reductions.
252252
Operation *mergeOp;
253-
/// Initial op
254-
Operation *initialOp;
253+
/// Initial values used for reduction.
254+
SmallVector<Value> initialValues;
255255
/// The loop operations that iterate over the tiles.
256256
SmallVector<LoopLikeOpInterface> loops;
257257
};

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,11 +170,12 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
170170
operation reduction. The tensor shape is equal to operation result
171171
shape with new dimension for each non zero tile size.
172172
}],
173-
/*retType=*/"FailureOr<Operation*>",
173+
/*retType=*/"FailureOr<Value>",
174174
/*methodName=*/"generateInitialTensorForPartialReduction",
175175
/*args=*/(ins
176176
"OpBuilder &":$b,
177-
"Location ":$loc,
177+
"Location":$loc,
178+
"int64_t":$resultNumber,
178179
"ArrayRef<OpFoldResult>":$sizes,
179180
"ArrayRef<int>":$reductionDim),
180181
/*methodBody=*/"",

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2523,7 +2523,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
25232523

25242524
if (failed(result))
25252525
return emitDefaultSilenceableFailure(target);
2526-
results.push_back(result->initialOp);
2526+
for (Value initValue : result->initialValues)
2527+
results.push_back(initValue.getDefiningOp());
25272528
results.push_back(result->parallelTiledOp);
25282529
results.push_back(result->mergeOp);
25292530
results.push_back(result->loops.front());
@@ -2574,7 +2575,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
25742575
diag.attachNote(target.getLoc()) << "target operation";
25752576
return diag;
25762577
}
2577-
results.push_back(result->initialOp);
2578+
for (Value initValue : result->initialValues)
2579+
results.push_back(initValue.getDefiningOp());
25782580
results.push_back(result->parallelTiledOp);
25792581
results.push_back(result->mergeOp);
25802582
results.push_back(result->loops);

mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,12 @@ static Value createDestinationPassingStyleInitOperand(
155155
tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
156156
PartialReductionOpInterface partialReductionIface =
157157
llvm::cast<PartialReductionOpInterface>(op.getOperation());
158-
FailureOr<Operation *> reductionNeutralTensorOp =
158+
assert(op->getNumResults() == 1 && "Multiple results not supported.");
159+
FailureOr<Value> reductionNeutralTensor =
159160
partialReductionIface.generateInitialTensorForPartialReduction(
160-
builder, builder.getLoc(), shape, {});
161-
assert(succeeded(reductionNeutralTensorOp));
162-
builder.create<scf::YieldOp>(
163-
reductionNeutralTensorOp.value()->getResult(0));
161+
builder, builder.getLoc(), 0, shape, {});
162+
assert(succeeded(reductionNeutralTensor));
163+
builder.create<scf::YieldOp>(reductionNeutralTensor.value());
164164
}
165165
return ifOp.getResult(0);
166166
}
@@ -173,8 +173,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
173173
ImplicitLocOpBuilder &builder) {
174174
// TODO: add support for multiple destination passing style initial value
175175
// operands.
176-
// PartialReductionOpInterface::generateInitialTensorForPartialReduction
177-
// needs to also support multiple DPS initial operands.
176+
assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
178177
SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
179178
auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
180179
Value spmdizedInitOperand =

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -657,9 +657,9 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
657657
auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
658658

659659
// Ops implementing PartialReductionOpInterface are not necessarily expected
660-
// to implement TilingInterface.. This cast is unsafe atm.
660+
// to implement DestinationStyleOpInterface. This cast is unsafe atm.
661661
// TODO: proper core mechanism to tie interfaces together.
662-
// TODO: this function requires a pair of interfaces ..
662+
// TODO: this function requires a pair of interfaces.
663663
auto destinationStyleOp =
664664
dyn_cast<DestinationStyleOpInterface>(op.getOperation());
665665
if (!destinationStyleOp)
@@ -671,9 +671,6 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
671671
return b.notifyMatchFailure(op, "not a linalg op");
672672

673673
SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
674-
if (op->getNumResults() != 1)
675-
return b.notifyMatchFailure(
676-
op, "don't support ops with multiple results for now");
677674

678675
SmallVector<utils::IteratorType> iterators =
679676
tilingInterfaceOp.getLoopIteratorTypes();
@@ -692,12 +689,17 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
692689
op, "reduction dimension must be mapped to threads");
693690

694691
// 1. Create the inital tensor value.
695-
FailureOr<Operation *> identityTensor =
696-
op.generateInitialTensorForPartialReduction(b, loc, numThreads,
697-
reductionDim);
698-
if (failed(identityTensor))
699-
return b.notifyMatchFailure(op,
700-
"cannot create a tensor of identity value.");
692+
SmallVector<Value> initTensors;
693+
initTensors.reserve(op->getNumResults());
694+
for (int idx : llvm::seq<int>(0, op->getNumResults())) {
695+
FailureOr<Value> initValue = op.generateInitialTensorForPartialReduction(
696+
b, loc, idx, numThreads, reductionDim);
697+
if (failed(initValue))
698+
return b.notifyMatchFailure(
699+
op, "cannot create a tensor of identity value for result " +
700+
std::to_string(idx));
701+
initTensors.push_back(initValue.value());
702+
}
701703

702704
// Gather destination tensors.
703705
SmallVector<Value> dest;
@@ -715,8 +717,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
715717

716718
// 2. Create the ForallOp with an empty region.
717719
scf::ForallOp forallOp = b.create<scf::ForallOp>(
718-
loc, getAsOpFoldResult(materializedNonZeroNumThreads),
719-
(*identityTensor)->getResults(), mapping);
720+
loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
721+
mapping);
720722

721723
// 3. Calculate the tile offsets and sizes for the subsequent loop that will
722724
// be nested under `forallOp`.
@@ -726,7 +728,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
726728
/*nominalTileSizes=*/std::nullopt, tiledOffsets,
727729
tiledSizes);
728730

729-
// 4. Clone the tileable op and update its destination operands to use the
731+
// 4b. Clone the tileable op and update its destination operands to use the
730732
// output bbArgs of the ForallOp.
731733
SmallVector<Value> tilingResults;
732734
ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
@@ -838,7 +840,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
838840

839841
// 8. Return.
840842
ForallReductionTilingResult results;
841-
results.initialOp = *identityTensor;
843+
results.initialValues = initTensors;
842844
results.loops = forallOp;
843845
results.parallelTiledOp = tiledOp;
844846
results.mergeOp = mergeOp;

0 commit comments

Comments
 (0)