@@ -1103,28 +1103,39 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
1103
1103
static void getGenericEffectsImpl (
1104
1104
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1105
1105
&effects,
1106
- ValueRange results, const ValueRange inputOperands,
1106
+ LinalgOp linalgOp, ValueRange results, const ValueRange inputOperands,
1107
1107
ValueRange outputOperands) {
1108
- for (auto operand : inputOperands) {
1108
+ for (auto [index, operand] : llvm::enumerate ( inputOperands) ) {
1109
1109
if (!llvm::isa<MemRefType>(operand.getType ()))
1110
1110
continue ;
1111
- effects.emplace_back (MemoryEffects::Read::get (), operand,
1112
- SideEffects::DefaultResource::get ());
1111
+ if (linalgOp.payloadUsesValueFromOperand (&linalgOp->getOpOperand (index))) {
1112
+ effects.emplace_back (MemoryEffects::Read::get (), 0 , true , operand,
1113
+ SideEffects::DefaultResource::get ());
1114
+ }
1113
1115
}
1114
- for (auto operand : outputOperands) {
1116
+ unsigned inputOperandSize = inputOperands.size ();
1117
+ unsigned usedOutputSize =
1118
+ linalgOp.getOpOperandsMatchingBBargs ().size () - inputOperandSize;
1119
+
1120
+ for (auto [index, operand] : llvm::enumerate (outputOperands)) {
1115
1121
if (!llvm::isa<MemRefType>(operand.getType ()))
1116
1122
continue ;
1117
- effects.emplace_back (MemoryEffects::Read::get (), operand,
1118
- SideEffects::DefaultResource::get ());
1119
- effects.emplace_back (MemoryEffects::Write::get (), operand,
1123
+ if (index < usedOutputSize &&
1124
+ linalgOp.payloadUsesValueFromOperand (
1125
+ &linalgOp->getOpOperand (index + inputOperandSize))) {
1126
+ effects.emplace_back (MemoryEffects::Read::get (), 0 , true , operand,
1127
+ SideEffects::DefaultResource::get ());
1128
+ }
1129
+ effects.emplace_back (MemoryEffects::Write::get (), 0 , true , operand,
1120
1130
SideEffects::DefaultResource::get ());
1121
1131
}
1122
1132
}
1123
1133
1124
1134
void GenericOp::getEffects (
1125
1135
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1126
1136
&effects) {
1127
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1137
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()),
1138
+ getOperation ()->getResults (), getDpsInputs (),
1128
1139
getDpsInits ());
1129
1140
}
1130
1141
@@ -1473,7 +1484,8 @@ ArrayAttr MapOp::getIndexingMaps() {
1473
1484
void MapOp::getEffects (
1474
1485
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1475
1486
&effects) {
1476
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1487
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()),
1488
+ getOperation ()->getResults (), getDpsInputs (),
1477
1489
getDpsInits ());
1478
1490
}
1479
1491
@@ -1542,7 +1554,8 @@ ArrayAttr ReduceOp::getIndexingMaps() {
1542
1554
void ReduceOp::getEffects (
1543
1555
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1544
1556
&effects) {
1545
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1557
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()),
1558
+ getOperation ()->getResults (), getDpsInputs (),
1546
1559
getDpsInits ());
1547
1560
}
1548
1561
@@ -1827,7 +1840,8 @@ ArrayAttr TransposeOp::getIndexingMaps() {
1827
1840
void TransposeOp::getEffects (
1828
1841
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1829
1842
&effects) {
1830
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1843
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()),
1844
+ getOperation ()->getResults (), getDpsInputs (),
1831
1845
getDpsInits ());
1832
1846
}
1833
1847
@@ -1965,7 +1979,8 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
1965
1979
void BroadcastOp::getEffects (
1966
1980
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1967
1981
&effects) {
1968
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1982
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()),
1983
+ getOperation ()->getResults (), getDpsInputs (),
1969
1984
getDpsInits ());
1970
1985
}
1971
1986
@@ -2494,7 +2509,8 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
2494
2509
void SoftmaxOp::getEffects (
2495
2510
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2496
2511
&effects) {
2497
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
2512
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()),
2513
+ getOperation ()->getResults (), getDpsInputs (),
2498
2514
getDpsInits ());
2499
2515
}
2500
2516
0 commit comments