Skip to content

Commit 8d3b0c8

Browse files
cxy-1993chenxunyu
authored andcommitted
[mlir][linalg] Add more precise memory effects to linalg op
1 parent 79a6a7e commit 8d3b0c8

File tree

2 files changed

+36
-23
lines changed

2 files changed

+36
-23
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,29 +1103,35 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
11031103
static void getGenericEffectsImpl(
11041104
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11051105
&effects,
1106-
ValueRange results, const ValueRange inputOperands,
1107-
ValueRange outputOperands) {
1108-
for (auto operand : inputOperands) {
1106+
LinalgOp linalgOp) {
1107+
ValueRange inputOperands = linalgOp.getDpsInputs();
1108+
for (auto [index, operand] : llvm::enumerate(inputOperands)) {
11091109
if (!llvm::isa<MemRefType>(operand.getType()))
11101110
continue;
1111-
effects.emplace_back(MemoryEffects::Read::get(), operand,
1112-
SideEffects::DefaultResource::get());
1111+
if (linalgOp.payloadUsesValueFromOperand(&linalgOp->getOpOperand(index))) {
1112+
effects.emplace_back(MemoryEffects::Read::get(), operand, 0, true,
1113+
SideEffects::DefaultResource::get());
1114+
}
11131115
}
1114-
for (auto operand : outputOperands) {
1116+
unsigned inputOperandSize = inputOperands.size();
1117+
1118+
for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInits())) {
11151119
if (!llvm::isa<MemRefType>(operand.getType()))
11161120
continue;
1117-
effects.emplace_back(MemoryEffects::Read::get(), operand,
1118-
SideEffects::DefaultResource::get());
1119-
effects.emplace_back(MemoryEffects::Write::get(), operand,
1121+
if (linalgOp.payloadUsesValueFromOperand(
1122+
&linalgOp->getOpOperand(index + inputOperandSize))) {
1123+
effects.emplace_back(MemoryEffects::Read::get(), operand, 0, true,
1124+
SideEffects::DefaultResource::get());
1125+
}
1126+
effects.emplace_back(MemoryEffects::Write::get(), operand, 0, true,
11201127
SideEffects::DefaultResource::get());
11211128
}
11221129
}
11231130

11241131
void GenericOp::getEffects(
11251132
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11261133
&effects) {
1127-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1128-
getDpsInits());
1134+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
11291135
}
11301136

11311137
LogicalResult GenericOp::verify() { return success(); }
@@ -1473,8 +1479,7 @@ ArrayAttr MapOp::getIndexingMaps() {
14731479
void MapOp::getEffects(
14741480
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
14751481
&effects) {
1476-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1477-
getDpsInits());
1482+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
14781483
}
14791484

14801485
//===----------------------------------------------------------------------===//
@@ -1542,8 +1547,7 @@ ArrayAttr ReduceOp::getIndexingMaps() {
15421547
void ReduceOp::getEffects(
15431548
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
15441549
&effects) {
1545-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1546-
getDpsInits());
1550+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
15471551
}
15481552

15491553
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
@@ -1827,8 +1831,7 @@ ArrayAttr TransposeOp::getIndexingMaps() {
18271831
void TransposeOp::getEffects(
18281832
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
18291833
&effects) {
1830-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1831-
getDpsInits());
1834+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
18321835
}
18331836

18341837
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
@@ -1965,8 +1968,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
19651968
void BroadcastOp::getEffects(
19661969
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
19671970
&effects) {
1968-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1969-
getDpsInits());
1971+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
19701972
}
19711973

19721974
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2494,8 +2496,20 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
24942496
void SoftmaxOp::getEffects(
24952497
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
24962498
&effects) {
2497-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
2498-
getDpsInits());
2499+
for (auto operand : getDpsInputs()) {
2500+
if (!llvm::isa<MemRefType>(operand.getType()))
2501+
continue;
2502+
effects.emplace_back(MemoryEffects::Read::get(), operand,
2503+
SideEffects::DefaultResource::get());
2504+
}
2505+
for (auto operand : getDpsInits()) {
2506+
if (!llvm::isa<MemRefType>(operand.getType()))
2507+
continue;
2508+
effects.emplace_back(MemoryEffects::Read::get(), operand,
2509+
SideEffects::DefaultResource::get());
2510+
effects.emplace_back(MemoryEffects::Write::get(), operand,
2511+
SideEffects::DefaultResource::get());
2512+
}
24992513
}
25002514

25012515
// Helper functions for softmax decomposition.

mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,8 +659,7 @@ LogicalResult {0}::fold(FoldAdaptor,
659659
void {0}::getEffects(SmallVectorImpl<
660660
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
661661
if (hasPureTensorSemantics()) return;
662-
getGenericEffectsImpl(effects,
663-
getOperation()->getResults(), getDpsInputs(), getDpsInits());
662+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
664663
}
665664
)FMT";
666665

0 commit comments

Comments
 (0)