Skip to content

Commit 24e33b5

Browse files
authored
[mlir] Implement DestinationStyleOpInterface for scf::ForallOp (#66981)
`scf::ForallOp` has `shared_outs` tensor operands into which partial results are inserted in the parallel terminator. The `scf::ForallOp` returns one tensor for each `shared_out`, which then contains the combined result from all threads. Since the parallel terminator cannot change the shape of the `shared_out`, `scf::ForallOp` is a `DestinationStyleOp`, and this patch implements the interface by declaring the `outputs` operands as `inits` in the language of the DPS interface. For this change to work, we need to add an exception to the pattern that folds `tensor.cast` ops into DPS ops, because `scf::ForallOp` needs special handling of its `BlockArgument` type during this folding.
1 parent 1767d81 commit 24e33b5

File tree

3 files changed

+12
-0
lines changed

3 files changed

+12
-0
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCF.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/IR/BuiltinTypes.h"
2020
#include "mlir/IR/RegionKindInterface.h"
2121
#include "mlir/Interfaces/ControlFlowInterfaces.h"
22+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2223
#include "mlir/Interfaces/InferTypeOpInterface.h"
2324
#include "mlir/Interfaces/LoopLikeInterface.h"
2425
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ include "mlir/Interfaces/ControlFlowInterfaces.td"
1717
include "mlir/Interfaces/LoopLikeInterface.td"
1818
include "mlir/IR/RegionKindInterface.td"
1919
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
20+
include "mlir/Interfaces/DestinationStyleOpInterface.td"
2021
include "mlir/Interfaces/InferTypeOpInterface.td"
2122
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
2223
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -333,6 +334,7 @@ def ForallOp : SCF_Op<"forall", [
333334
RecursiveMemoryEffects,
334335
SingleBlockImplicitTerminator<"scf::InParallelOp">,
335336
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
337+
DestinationStyleOpInterface
336338
]> {
337339
let summary = "evaluate a block multiple times in parallel";
338340
let description = [{
@@ -630,6 +632,9 @@ def ForallOp : SCF_Op<"forall", [
630632
Location loc);
631633

632634
InParallelOp getTerminator();
635+
636+
// Declare the shared_outs as inits/outs to DestinationStyleOpInterface.
637+
// Forwards to getOutputsMutable(): the op's `outputs` operands are exposed as the DPS inits.
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
633638
}];
634639
}
635640

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/OpDefinition.h"
2323
#include "mlir/IR/TypeUtilities.h"
2424
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
25+
#include "mlir/Interfaces/LoopLikeInterface.h"
2526
#include "mlir/Support/MathExtras.h"
2627
#include "llvm/ADT/DenseSet.h"
2728
#include "llvm/ADT/STLExtras.h"
@@ -3970,6 +3971,11 @@ struct FoldTensorCastProducerOp
39703971
if (isa<InsertSliceOp>(op.getOperation()))
39713972
return failure();
39723973

3974+
// Exclude DPS ops that are also LoopLike from this pattern as they
3975+
// might need special handling of attached regions.
3976+
if (isa<LoopLikeOpInterface>(op.getOperation()))
3977+
return failure();
3978+
39733979
// If no operand comes from a tensor::CastOp and can be folded then fail.
39743980
bool hasTensorCastOperand =
39753981
llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {

0 commit comments

Comments
 (0)