Skip to content

Commit a3bd52c

Browse files
committed
[mlir][linalg] Add more precise memory effects to linalg op
1 parent 79a6a7e commit a3bd52c

File tree

2 files changed

+31
-15
lines changed

2 files changed

+31
-15
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,28 +1103,39 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
11031103
static void getGenericEffectsImpl(
11041104
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11051105
&effects,
1106-
ValueRange results, const ValueRange inputOperands,
1106+
LinalgOp linalgOp, ValueRange results, const ValueRange inputOperands,
11071107
ValueRange outputOperands) {
1108-
for (auto operand : inputOperands) {
1108+
for (auto [index, operand] : llvm::enumerate(inputOperands)) {
11091109
if (!llvm::isa<MemRefType>(operand.getType()))
11101110
continue;
1111-
effects.emplace_back(MemoryEffects::Read::get(), operand,
1112-
SideEffects::DefaultResource::get());
1111+
if (linalgOp.payloadUsesValueFromOperand(&linalgOp->getOpOperand(index))) {
1112+
effects.emplace_back(MemoryEffects::Read::get(), 0, true, operand,
1113+
SideEffects::DefaultResource::get());
1114+
}
11131115
}
1114-
for (auto operand : outputOperands) {
1116+
unsigned inputOperandSize = inputOperands.size();
1117+
unsigned usedOutputSize =
1118+
linalgOp.getOpOperandsMatchingBBargs().size() - inputOperandSize;
1119+
1120+
for (auto [index, operand] : llvm::enumerate(outputOperands)) {
11151121
if (!llvm::isa<MemRefType>(operand.getType()))
11161122
continue;
1117-
effects.emplace_back(MemoryEffects::Read::get(), operand,
1118-
SideEffects::DefaultResource::get());
1119-
effects.emplace_back(MemoryEffects::Write::get(), operand,
1123+
if (index < usedOutputSize &&
1124+
linalgOp.payloadUsesValueFromOperand(
1125+
&linalgOp->getOpOperand(index + inputOperandSize))) {
1126+
effects.emplace_back(MemoryEffects::Read::get(), 0, true, operand,
1127+
SideEffects::DefaultResource::get());
1128+
}
1129+
effects.emplace_back(MemoryEffects::Write::get(), 0, true, operand,
11201130
SideEffects::DefaultResource::get());
11211131
}
11221132
}
11231133

11241134
void GenericOp::getEffects(
11251135
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11261136
&effects) {
1127-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1137+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
1138+
getOperation()->getResults(), getDpsInputs(),
11281139
getDpsInits());
11291140
}
11301141

@@ -1473,7 +1484,8 @@ ArrayAttr MapOp::getIndexingMaps() {
14731484
void MapOp::getEffects(
14741485
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
14751486
&effects) {
1476-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1487+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
1488+
getOperation()->getResults(), getDpsInputs(),
14771489
getDpsInits());
14781490
}
14791491

@@ -1542,7 +1554,8 @@ ArrayAttr ReduceOp::getIndexingMaps() {
15421554
void ReduceOp::getEffects(
15431555
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
15441556
&effects) {
1545-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1557+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
1558+
getOperation()->getResults(), getDpsInputs(),
15461559
getDpsInits());
15471560
}
15481561

@@ -1827,7 +1840,8 @@ ArrayAttr TransposeOp::getIndexingMaps() {
18271840
void TransposeOp::getEffects(
18281841
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
18291842
&effects) {
1830-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1843+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
1844+
getOperation()->getResults(), getDpsInputs(),
18311845
getDpsInits());
18321846
}
18331847

@@ -1965,7 +1979,8 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
19651979
void BroadcastOp::getEffects(
19661980
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
19671981
&effects) {
1968-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1982+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
1983+
getOperation()->getResults(), getDpsInputs(),
19691984
getDpsInits());
19701985
}
19711986

@@ -2494,7 +2509,8 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
24942509
void SoftmaxOp::getEffects(
24952510
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
24962511
&effects) {
2497-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
2512+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
2513+
getOperation()->getResults(), getDpsInputs(),
24982514
getDpsInits());
24992515
}
25002516

mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ LogicalResult {0}::fold(FoldAdaptor,
659659
void {0}::getEffects(SmallVectorImpl<
660660
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
661661
if (hasPureTensorSemantics()) return;
662-
getGenericEffectsImpl(effects,
662+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()),
663663
getOperation()->getResults(), getDpsInputs(), getDpsInits());
664664
}
665665
)FMT";

0 commit comments

Comments
 (0)