@@ -1103,29 +1103,35 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
1103
1103
static void getGenericEffectsImpl (
1104
1104
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1105
1105
&effects,
1106
- ValueRange results, const ValueRange inputOperands,
1107
- ValueRange outputOperands) {
1108
- for (auto operand : inputOperands) {
1106
+ LinalgOp linalgOp) {
1107
+ SmallVector<Value> inputOperands = linalgOp. getDpsInputs ();
1108
+ for (auto [index, operand] : llvm::enumerate ( inputOperands) ) {
1109
1109
if (!llvm::isa<MemRefType>(operand.getType ()))
1110
1110
continue ;
1111
- effects.emplace_back (MemoryEffects::Read::get (), operand,
1112
- SideEffects::DefaultResource::get ());
1111
+ if (linalgOp.payloadUsesValueFromOperand (&linalgOp->getOpOperand (index))) {
1112
+ effects.emplace_back (MemoryEffects::Read::get (), operand, 0 , true ,
1113
+ SideEffects::DefaultResource::get ());
1114
+ }
1113
1115
}
1114
- for (auto operand : outputOperands) {
1116
+ unsigned inputOperandSize = inputOperands.size ();
1117
+
1118
+ for (auto [index, operand] : llvm::enumerate (linalgOp.getDpsInits ())) {
1115
1119
if (!llvm::isa<MemRefType>(operand.getType ()))
1116
1120
continue ;
1117
- effects.emplace_back (MemoryEffects::Read::get (), operand,
1118
- SideEffects::DefaultResource::get ());
1119
- effects.emplace_back (MemoryEffects::Write::get (), operand,
1121
+ if (linalgOp.payloadUsesValueFromOperand (
1122
+ &linalgOp->getOpOperand (index + inputOperandSize))) {
1123
+ effects.emplace_back (MemoryEffects::Read::get (), operand, 0 , true ,
1124
+ SideEffects::DefaultResource::get ());
1125
+ }
1126
+ effects.emplace_back (MemoryEffects::Write::get (), operand, 0 , true ,
1120
1127
SideEffects::DefaultResource::get ());
1121
1128
}
1122
1129
}
1123
1130
1124
1131
void GenericOp::getEffects (
1125
1132
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1126
1133
&effects) {
1127
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1128
- getDpsInits ());
1134
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1129
1135
}
1130
1136
1131
1137
LogicalResult GenericOp::verify () { return success (); }
@@ -1473,8 +1479,7 @@ ArrayAttr MapOp::getIndexingMaps() {
1473
1479
void MapOp::getEffects (
1474
1480
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1475
1481
&effects) {
1476
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1477
- getDpsInits ());
1482
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1478
1483
}
1479
1484
1480
1485
// ===----------------------------------------------------------------------===//
@@ -1542,8 +1547,7 @@ ArrayAttr ReduceOp::getIndexingMaps() {
1542
1547
void ReduceOp::getEffects (
1543
1548
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1544
1549
&effects) {
1545
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1546
- getDpsInits ());
1550
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1547
1551
}
1548
1552
1549
1553
static ParseResult parseDenseI64ArrayAttr (OpAsmParser &parser,
@@ -1827,8 +1831,7 @@ ArrayAttr TransposeOp::getIndexingMaps() {
1827
1831
void TransposeOp::getEffects (
1828
1832
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1829
1833
&effects) {
1830
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1831
- getDpsInits ());
1834
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1832
1835
}
1833
1836
1834
1837
LogicalResult TransposeOp::fold (FoldAdaptor adaptor,
@@ -1965,8 +1968,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
1965
1968
void BroadcastOp::getEffects (
1966
1969
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1967
1970
&effects) {
1968
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1969
- getDpsInits ());
1971
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1970
1972
}
1971
1973
1972
1974
void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &results,
@@ -2494,8 +2496,20 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
2494
2496
void SoftmaxOp::getEffects (
2495
2497
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2496
2498
&effects) {
2497
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
2498
- getDpsInits ());
2499
+ for (auto operand : getDpsInputs ()) {
2500
+ if (!llvm::isa<MemRefType>(operand.getType ()))
2501
+ continue ;
2502
+ effects.emplace_back (MemoryEffects::Read::get (), operand,
2503
+ SideEffects::DefaultResource::get ());
2504
+ }
2505
+ for (auto operand : getDpsInits ()) {
2506
+ if (!llvm::isa<MemRefType>(operand.getType ()))
2507
+ continue ;
2508
+ effects.emplace_back (MemoryEffects::Read::get (), operand,
2509
+ SideEffects::DefaultResource::get ());
2510
+ effects.emplace_back (MemoryEffects::Write::get (), operand,
2511
+ SideEffects::DefaultResource::get ());
2512
+ }
2499
2513
}
2500
2514
2501
2515
// Helper functions for softmax decomposition.
0 commit comments