@@ -1122,29 +1122,35 @@ ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
1122
1122
static void getGenericEffectsImpl (
1123
1123
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1124
1124
&effects,
1125
- ValueRange results, const ValueRange inputOperands,
1126
- ValueRange outputOperands) {
1127
- for (auto operand : inputOperands) {
1125
+ LinalgOp linalgOp) {
1126
+ SmallVector<Value> inputOperands = linalgOp. getDpsInputs ();
1127
+ for (auto [index, operand] : llvm::enumerate ( inputOperands) ) {
1128
1128
if (!llvm::isa<MemRefType>(operand.getType ()))
1129
1129
continue ;
1130
- effects.emplace_back (MemoryEffects::Read::get (), operand,
1131
- SideEffects::DefaultResource::get ());
1130
+ if (linalgOp.payloadUsesValueFromOperand (&linalgOp->getOpOperand (index))) {
1131
+ effects.emplace_back (MemoryEffects::Read::get (), operand, 0 , true ,
1132
+ SideEffects::DefaultResource::get ());
1133
+ }
1132
1134
}
1133
- for (auto operand : outputOperands) {
1135
+ unsigned inputOperandSize = inputOperands.size ();
1136
+
1137
+ for (auto [index, operand] : llvm::enumerate (linalgOp.getDpsInits ())) {
1134
1138
if (!llvm::isa<MemRefType>(operand.getType ()))
1135
1139
continue ;
1136
- effects.emplace_back (MemoryEffects::Read::get (), operand,
1137
- SideEffects::DefaultResource::get ());
1138
- effects.emplace_back (MemoryEffects::Write::get (), operand,
1140
+ if (linalgOp.payloadUsesValueFromOperand (
1141
+ &linalgOp->getOpOperand (index + inputOperandSize))) {
1142
+ effects.emplace_back (MemoryEffects::Read::get (), operand, 0 , true ,
1143
+ SideEffects::DefaultResource::get ());
1144
+ }
1145
+ effects.emplace_back (MemoryEffects::Write::get (), operand, 0 , true ,
1139
1146
SideEffects::DefaultResource::get ());
1140
1147
}
1141
1148
}
1142
1149
1143
1150
void GenericOp::getEffects (
1144
1151
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1145
1152
&effects) {
1146
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1147
- getDpsInits ());
1153
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1148
1154
}
1149
1155
1150
1156
LogicalResult GenericOp::verify () { return success (); }
@@ -1492,8 +1498,7 @@ ArrayAttr MapOp::getIndexingMaps() {
1492
1498
void MapOp::getEffects (
1493
1499
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1494
1500
&effects) {
1495
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1496
- getDpsInits ());
1501
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1497
1502
}
1498
1503
1499
1504
// ===----------------------------------------------------------------------===//
@@ -1561,8 +1566,7 @@ ArrayAttr ReduceOp::getIndexingMaps() {
1561
1566
void ReduceOp::getEffects (
1562
1567
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1563
1568
&effects) {
1564
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1565
- getDpsInits ());
1569
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1566
1570
}
1567
1571
1568
1572
static ParseResult parseDenseI64ArrayAttr (OpAsmParser &parser,
@@ -1846,8 +1850,7 @@ ArrayAttr TransposeOp::getIndexingMaps() {
1846
1850
void TransposeOp::getEffects (
1847
1851
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1848
1852
&effects) {
1849
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1850
- getDpsInits ());
1853
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1851
1854
}
1852
1855
1853
1856
LogicalResult TransposeOp::fold (FoldAdaptor adaptor,
@@ -1984,8 +1987,7 @@ ArrayAttr BroadcastOp::getIndexingMaps() {
1984
1987
void BroadcastOp::getEffects (
1985
1988
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
1986
1989
&effects) {
1987
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
1988
- getDpsInits ());
1990
+ getGenericEffectsImpl (effects, cast<LinalgOp>(getOperation ()));
1989
1991
}
1990
1992
1991
1993
void BroadcastOp::getCanonicalizationPatterns (RewritePatternSet &results,
@@ -2513,8 +2515,20 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
2513
2515
void SoftmaxOp::getEffects (
2514
2516
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
2515
2517
&effects) {
2516
- getGenericEffectsImpl (effects, getOperation ()->getResults (), getDpsInputs (),
2517
- getDpsInits ());
2518
+ for (auto operand : getDpsInputs ()) {
2519
+ if (!llvm::isa<MemRefType>(operand.getType ()))
2520
+ continue ;
2521
+ effects.emplace_back (MemoryEffects::Read::get (), operand,
2522
+ SideEffects::DefaultResource::get ());
2523
+ }
2524
+ for (auto operand : getDpsInits ()) {
2525
+ if (!llvm::isa<MemRefType>(operand.getType ()))
2526
+ continue ;
2527
+ effects.emplace_back (MemoryEffects::Read::get (), operand,
2528
+ SideEffects::DefaultResource::get ());
2529
+ effects.emplace_back (MemoryEffects::Write::get (), operand,
2530
+ SideEffects::DefaultResource::get ());
2531
+ }
2518
2532
}
2519
2533
2520
2534
// Helper functions for softmax decomposition.
0 commit comments