Skip to content

Commit 67321b7

Browse files
cxy-1993chenxunyu
authored andcommitted
[mlir][linalg] Add more precise memory effects to linalg op
1 parent 79a6a7e commit 67321b7

File tree

4 files changed

+48
-23
lines changed

4 files changed

+48
-23
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,9 @@ def LinalgStructuredInterface
322322
/*args=*/(ins "OpOperand *":$opOperand),
323323
/*methodBody=*/"",
324324
/*defaultImplementation=*/[{
325+
if ($_op.getOperation()->getRegion(0).empty()) {
326+
return true;
327+
}
325328
unsigned bbArgNumber = opOperand->getOperandNumber();
326329
// Init tensors have uses.
327330
return !getBlock()->getArgument(bbArgNumber).use_empty();

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,9 @@ def MapOp : LinalgStructuredBase_Op<"map", [
289289

290290
bool payloadUsesValueFromOperand(OpOperand * opOperand) {
291291
if (isDpsInit(opOperand)) return false;
292+
if (getOperation()->getRegion(0).empty()) {
293+
return true;
294+
}
292295
return !getMatchingBlockArgument(opOperand).use_empty();
293296
}
294297

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,29 +1103,38 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
11031103
static void getGenericEffectsImpl(
11041104
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11051105
&effects,
1106-
ValueRange results, const ValueRange inputOperands,
1107-
ValueRange outputOperands) {
1108-
for (auto operand : inputOperands) {
1106+
LinalgOp linalgOp) {
1107+
SmallVector<Value> inputOperands = linalgOp.getDpsInputs();
1108+
for (auto [index, operand] : llvm::enumerate(inputOperands)) {
11091109
if (!llvm::isa<MemRefType>(operand.getType()))
11101110
continue;
1111-
effects.emplace_back(MemoryEffects::Read::get(), operand,
1112-
SideEffects::DefaultResource::get());
1111+
if (linalgOp.payloadUsesValueFromOperand(&linalgOp->getOpOperand(index))) {
1112+
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
1113+
/*effectOnFullRegion=*/true,
1114+
SideEffects::DefaultResource::get());
1115+
}
11131116
}
1114-
for (auto operand : outputOperands) {
1117+
unsigned inputOperandSize = inputOperands.size();
1118+
1119+
for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInits())) {
11151120
if (!llvm::isa<MemRefType>(operand.getType()))
11161121
continue;
1117-
effects.emplace_back(MemoryEffects::Read::get(), operand,
1118-
SideEffects::DefaultResource::get());
1119-
effects.emplace_back(MemoryEffects::Write::get(), operand,
1122+
if (linalgOp.payloadUsesValueFromOperand(
1123+
&linalgOp->getOpOperand(index + inputOperandSize))) {
1124+
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
1125+
/*effectOnFullRegion=*/true,
1126+
SideEffects::DefaultResource::get());
1127+
}
1128+
effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
1129+
/*effectOnFullRegion=*/true,
11201130
SideEffects::DefaultResource::get());
11211131
}
11221132
}
11231133

11241134
void GenericOp::getEffects(
11251135
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11261136
&effects) {
1127-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1128-
getDpsInits());
1137+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
11291138
}
11301139

11311140
LogicalResult GenericOp::verify() { return success(); }
@@ -1473,8 +1482,7 @@ ArrayAttr MapOp::getIndexingMaps() {
14731482
void MapOp::getEffects(
14741483
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
14751484
&effects) {
1476-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1477-
getDpsInits());
1485+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
14781486
}
14791487

14801488
//===----------------------------------------------------------------------===//
@@ -1542,8 +1550,7 @@ ArrayAttr ReduceOp::getIndexingMaps() {
15421550
void ReduceOp::getEffects(
15431551
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
15441552
&effects) {
1545-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1546-
getDpsInits());
1553+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
15471554
}
15481555

15491556
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
@@ -1827,8 +1834,7 @@ ArrayAttr TransposeOp::getIndexingMaps() {
18271834
void TransposeOp::getEffects(
18281835
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
18291836
&effects) {
1830-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1831-
getDpsInits());
1837+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
18321838
}
18331839

18341840
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
@@ -1965,8 +1971,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
19651971
void BroadcastOp::getEffects(
19661972
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
19671973
&effects) {
1968-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1969-
getDpsInits());
1974+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
19701975
}
19711976

19721977
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2494,8 +2499,23 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
24942499
void SoftmaxOp::getEffects(
24952500
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
24962501
&effects) {
2497-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
2498-
getDpsInits());
2502+
for (Value operand : getDpsInputs()) {
2503+
if (!llvm::isa<MemRefType>(operand.getType()))
2504+
continue;
2505+
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
2506+
/*effectOnFullRegion=*/true,
2507+
SideEffects::DefaultResource::get());
2508+
}
2509+
for (Value operand : getDpsInits()) {
2510+
if (!llvm::isa<MemRefType>(operand.getType()))
2511+
continue;
2512+
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
2513+
/*effectOnFullRegion=*/true,
2514+
SideEffects::DefaultResource::get());
2515+
effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
2516+
/*effectOnFullRegion=*/true,
2517+
SideEffects::DefaultResource::get());
2518+
}
24992519
}
25002520

25012521
// Helper functions for softmax decomposition.

mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,8 +659,7 @@ LogicalResult {0}::fold(FoldAdaptor,
659659
void {0}::getEffects(SmallVectorImpl<
660660
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
661661
if (hasPureTensorSemantics()) return;
662-
getGenericEffectsImpl(effects,
663-
getOperation()->getResults(), getDpsInputs(), getDpsInits());
662+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
664663
}
665664
)FMT";
666665

0 commit comments

Comments
 (0)