Skip to content

Commit 779df1f

Browse files
committed
[mlir][linalg] Add more precise memory effects to linalg op
1 parent d7bb072 commit 779df1f

File tree

3 files changed

+39
-23
lines changed

3 files changed

+39
-23
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,9 @@ def LinalgStructuredInterface
322322
/*args=*/(ins "OpOperand *":$opOperand),
323323
/*methodBody=*/"",
324324
/*defaultImplementation=*/[{
325+
if ($_op.getOperation()->getRegion(0).empty()) {
326+
return true;
327+
}
325328
unsigned bbArgNumber = opOperand->getOperandNumber();
326329
// Init tensors have uses.
327330
return !getBlock()->getArgument(bbArgNumber).use_empty();

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1122,29 +1122,35 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
11221122
static void getGenericEffectsImpl(
11231123
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11241124
&effects,
1125-
ValueRange results, const ValueRange inputOperands,
1126-
ValueRange outputOperands) {
1127-
for (auto operand : inputOperands) {
1125+
LinalgOp linalgOp) {
1126+
SmallVector<Value> inputOperands = linalgOp.getDpsInputs();
1127+
for (auto [index, operand] : llvm::enumerate(inputOperands)) {
11281128
if (!llvm::isa<MemRefType>(operand.getType()))
11291129
continue;
1130-
effects.emplace_back(MemoryEffects::Read::get(), operand,
1131-
SideEffects::DefaultResource::get());
1130+
if (linalgOp.payloadUsesValueFromOperand(&linalgOp->getOpOperand(index))) {
1131+
effects.emplace_back(MemoryEffects::Read::get(), operand, 0, true,
1132+
SideEffects::DefaultResource::get());
1133+
}
11321134
}
1133-
for (auto operand : outputOperands) {
1135+
unsigned inputOperandSize = inputOperands.size();
1136+
1137+
for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInits())) {
11341138
if (!llvm::isa<MemRefType>(operand.getType()))
11351139
continue;
1136-
effects.emplace_back(MemoryEffects::Read::get(), operand,
1137-
SideEffects::DefaultResource::get());
1138-
effects.emplace_back(MemoryEffects::Write::get(), operand,
1140+
if (linalgOp.payloadUsesValueFromOperand(
1141+
&linalgOp->getOpOperand(index + inputOperandSize))) {
1142+
effects.emplace_back(MemoryEffects::Read::get(), operand, 0, true,
1143+
SideEffects::DefaultResource::get());
1144+
}
1145+
effects.emplace_back(MemoryEffects::Write::get(), operand, 0, true,
11391146
SideEffects::DefaultResource::get());
11401147
}
11411148
}
11421149

11431150
void GenericOp::getEffects(
11441151
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
11451152
&effects) {
1146-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1147-
getDpsInits());
1153+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
11481154
}
11491155

11501156
LogicalResult GenericOp::verify() { return success(); }
@@ -1492,8 +1498,7 @@ ArrayAttr MapOp::getIndexingMaps() {
14921498
void MapOp::getEffects(
14931499
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
14941500
&effects) {
1495-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1496-
getDpsInits());
1501+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
14971502
}
14981503

14991504
//===----------------------------------------------------------------------===//
@@ -1561,8 +1566,7 @@ ArrayAttr ReduceOp::getIndexingMaps() {
15611566
void ReduceOp::getEffects(
15621567
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
15631568
&effects) {
1564-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1565-
getDpsInits());
1569+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
15661570
}
15671571

15681572
static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
@@ -1846,8 +1850,7 @@ ArrayAttr TransposeOp::getIndexingMaps() {
18461850
void TransposeOp::getEffects(
18471851
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
18481852
&effects) {
1849-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1850-
getDpsInits());
1853+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
18511854
}
18521855

18531856
LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
@@ -1984,8 +1987,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
19841987
void BroadcastOp::getEffects(
19851988
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
19861989
&effects) {
1987-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
1988-
getDpsInits());
1990+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
19891991
}
19901992

19911993
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -2513,8 +2515,20 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
25132515
void SoftmaxOp::getEffects(
25142516
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
25152517
&effects) {
2516-
getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
2517-
getDpsInits());
2518+
for (auto operand : getDpsInputs()) {
2519+
if (!llvm::isa<MemRefType>(operand.getType()))
2520+
continue;
2521+
effects.emplace_back(MemoryEffects::Read::get(), operand,
2522+
SideEffects::DefaultResource::get());
2523+
}
2524+
for (auto operand : getDpsInits()) {
2525+
if (!llvm::isa<MemRefType>(operand.getType()))
2526+
continue;
2527+
effects.emplace_back(MemoryEffects::Read::get(), operand,
2528+
SideEffects::DefaultResource::get());
2529+
effects.emplace_back(MemoryEffects::Write::get(), operand,
2530+
SideEffects::DefaultResource::get());
2531+
}
25182532
}
25192533

25202534
// Helper functions for softmax decomposition.

mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -667,8 +667,7 @@ LogicalResult {0}::fold(FoldAdaptor,
667667
void {0}::getEffects(SmallVectorImpl<
668668
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
669669
if (hasPureTensorSemantics()) return;
670-
getGenericEffectsImpl(effects,
671-
getOperation()->getResults(), getDpsInputs(), getDpsInits());
670+
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
672671
}
673672
)FMT";
674673

0 commit comments

Comments
 (0)