Skip to content

[mlir][TilingInterface] Allow multiple results in PartialReductionOpInterface #92624

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1681,7 +1681,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
// TODO: support mixed static-dynamic (see TileUsingForallOp).
let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
let results = (outs TransformHandleTypeInterface:$fill_op,
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
TransformHandleTypeInterface:$split_linalg_op,
TransformHandleTypeInterface:$combining_linalg_op,
TransformHandleTypeInterface:$for_op);
Expand Down Expand Up @@ -1787,7 +1787,7 @@ def TileReductionUsingForallOp :
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$num_threads,
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
let results = (outs TransformHandleTypeInterface:$fill_op,
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
TransformHandleTypeInterface:$split_linalg_op,
TransformHandleTypeInterface:$combining_linalg_op,
TransformHandleTypeInterface:$forall_op);
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -876,8 +876,8 @@ struct ForallReductionTilingResult {
Operation *parallelTiledOp;
/// The final reduction operation merging all the partial reductions.
Operation *mergeOp;
/// The op initializing the tensor used for partial reductions.
Operation *initialOp;
/// Initial values used for partial reductions.
SmallVector<Value> initialValues;
/// The `scf.forall` operation that iterate over the tiles.
scf::ForallOp loops;
};
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ struct SCFReductionTilingResult {
Operation *parallelTiledOp;
/// The final reduction operation merging all the partial reductions.
Operation *mergeOp;
/// Initial op
Operation *initialOp;
/// Initial values used for reduction.
SmallVector<Value> initialValues;
/// The loop operations that iterate over the tiles.
SmallVector<LoopLikeOpInterface> loops;
};
Expand Down
4 changes: 2 additions & 2 deletions mlir/include/mlir/Interfaces/TilingInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,11 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
operation reduction. The tensor shape is equal to operation result
shape with new dimension for each non zero tile size.
}],
/*retType=*/"FailureOr<Operation*>",
/*retType=*/"FailureOr<SmallVector<Value>>",
/*methodName=*/"generateInitialTensorForPartialReduction",
/*args=*/(ins
"OpBuilder &":$b,
"Location ":$loc,
"Location":$loc,
"ArrayRef<OpFoldResult>":$sizes,
"ArrayRef<int>":$reductionDim),
/*methodBody=*/"",
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2523,7 +2523,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(

if (failed(result))
return emitDefaultSilenceableFailure(target);
results.push_back(result->initialOp);
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
results.push_back(result->parallelTiledOp);
results.push_back(result->mergeOp);
results.push_back(result->loops.front());
Expand Down Expand Up @@ -2574,7 +2575,8 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
diag.attachNote(target.getLoc()) << "target operation";
return diag;
}
results.push_back(result->initialOp);
for (Value initValue : result->initialValues)
results.push_back(initValue.getDefiningOp());
results.push_back(result->parallelTiledOp);
results.push_back(result->mergeOp);
results.push_back(result->loops);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,12 @@ static Value createDestinationPassingStyleInitOperand(
tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
PartialReductionOpInterface partialReductionIface =
llvm::cast<PartialReductionOpInterface>(op.getOperation());
FailureOr<Operation *> reductionNeutralTensorOp =
assert(op->getNumResults() == 1 && "Multiple results not supported.");
FailureOr<SmallVector<Value>> reductionNeutralTensor =
partialReductionIface.generateInitialTensorForPartialReduction(
builder, builder.getLoc(), shape, {});
assert(succeeded(reductionNeutralTensorOp));
builder.create<scf::YieldOp>(
reductionNeutralTensorOp.value()->getResult(0));
assert(succeeded(reductionNeutralTensor));
builder.create<scf::YieldOp>(reductionNeutralTensor.value());
}
return ifOp.getResult(0);
}
Expand All @@ -173,8 +173,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
ImplicitLocOpBuilder &builder) {
// TODO: add support for multiple destination passing style initial value
// operands.
// PartialReductionOpInterface::generateInitialTensorForPartialReduction
// needs to also support multiple DPS initial operands.
assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported.");
SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands);
auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber();
Value spmdizedInitOperand =
Expand Down
17 changes: 9 additions & 8 deletions mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,12 +692,13 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
op, "reduction dimension must be mapped to threads");

// 1. Create the inital tensor value.
FailureOr<Operation *> identityTensor =
FailureOr<SmallVector<Value>> maybeInitTensors =
op.generateInitialTensorForPartialReduction(b, loc, numThreads,
reductionDim);
if (failed(identityTensor))
return b.notifyMatchFailure(op,
"cannot create a tensor of identity value.");
if (failed(maybeInitTensors))
return b.notifyMatchFailure(
op, "Failed to create inital tensors for partial reduction");
SmallVector<Value> &initTensors = maybeInitTensors.value();

// Gather destination tensors.
SmallVector<Value> dest;
Expand All @@ -715,8 +716,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(

// 2. Create the ForallOp with an empty region.
scf::ForallOp forallOp = b.create<scf::ForallOp>(
loc, getAsOpFoldResult(materializedNonZeroNumThreads),
(*identityTensor)->getResults(), mapping);
loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors,
mapping);

// 3. Calculate the tile offsets and sizes for the subsequent loop that will
// be nested under `forallOp`.
Expand All @@ -726,7 +727,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
/*nominalTileSizes=*/std::nullopt, tiledOffsets,
tiledSizes);

// 4. Clone the tileable op and update its destination operands to use the
// 4b. Clone the tileable op and update its destination operands to use the
// output bbArgs of the ForallOp.
SmallVector<Value> tilingResults;
ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
Expand Down Expand Up @@ -838,7 +839,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(

// 8. Return.
ForallReductionTilingResult results;
results.initialOp = *identityTensor;
results.initialValues = initTensors;
results.loops = forallOp;
results.parallelTiledOp = tiledOp;
results.mergeOp = mergeOp;
Expand Down
Loading
Loading