Skip to content

Commit 5a7765d

Browse files
committed
[mlir][side effect] refactor(*): Include more precise side effects
This patch adds more precise side effects to the current ops with memory effects, allowing us to determine which OpOperands the operation reads or writes, rather than just recording the reading and writing of values.
1 parent 12fcca0 commit 5a7765d

File tree

8 files changed

+86
-28
lines changed

8 files changed

+86
-28
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ class AffineDmaStartOp
107107

108108
/// Returns the source MemRefType for this DMA operation.
109109
Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
110+
OpOperand &getSrcMemRefMutable() {
111+
return getOperation()->getOpOperand(getSrcMemRefOperandIndex());
112+
}
110113
MemRefType getSrcMemRefType() {
111114
return cast<MemRefType>(getSrcMemRef().getType());
112115
}
@@ -117,7 +120,8 @@ class AffineDmaStartOp
117120
/// Returns the affine map used to access the source memref.
118121
AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
119122
AffineMapAttr getSrcMapAttr() {
120-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getSrcMapAttrStrName()));
123+
return cast<AffineMapAttr>(
124+
*(*this)->getInherentAttr(getSrcMapAttrStrName()));
121125
}
122126

123127
/// Returns the source memref affine map indices for this DMA operation.
@@ -139,6 +143,9 @@ class AffineDmaStartOp
139143

140144
/// Returns the destination MemRefType for this DMA operation.
141145
Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
146+
OpOperand &getDstMemRefMutable() {
147+
return getOperation()->getOpOperand(getDstMemRefOperandIndex());
148+
}
142149
MemRefType getDstMemRefType() {
143150
return cast<MemRefType>(getDstMemRef().getType());
144151
}
@@ -156,7 +163,8 @@ class AffineDmaStartOp
156163
/// Returns the affine map used to access the destination memref.
157164
AffineMap getDstMap() { return getDstMapAttr().getValue(); }
158165
AffineMapAttr getDstMapAttr() {
159-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getDstMapAttrStrName()));
166+
return cast<AffineMapAttr>(
167+
*(*this)->getInherentAttr(getDstMapAttrStrName()));
160168
}
161169

162170
/// Returns the destination memref indices for this DMA operation.
@@ -173,6 +181,9 @@ class AffineDmaStartOp
173181

174182
/// Returns the Tag MemRef for this DMA operation.
175183
Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
184+
OpOperand &getTagMemRefMutable() {
185+
return getOperation()->getOpOperand(getTagMemRefOperandIndex());
186+
}
176187
MemRefType getTagMemRefType() {
177188
return cast<MemRefType>(getTagMemRef().getType());
178189
}
@@ -185,7 +196,8 @@ class AffineDmaStartOp
185196
/// Returns the affine map used to access the tag memref.
186197
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
187198
AffineMapAttr getTagMapAttr() {
188-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
199+
return cast<AffineMapAttr>(
200+
*(*this)->getInherentAttr(getTagMapAttrStrName()));
189201
}
190202

191203
/// Returns the tag memref indices for this DMA operation.
@@ -300,14 +312,16 @@ class AffineDmaWaitOp
300312

301313
/// Returns the Tag MemRef associated with the DMA operation being waited on.
302314
Value getTagMemRef() { return getOperand(0); }
315+
OpOperand &getTagMemRefMutable() { return getOperation()->getOpOperand(0); }
303316
MemRefType getTagMemRefType() {
304317
return cast<MemRefType>(getTagMemRef().getType());
305318
}
306319

307320
/// Returns the affine map used to access the tag memref.
308321
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
309322
AffineMapAttr getTagMapAttr() {
310-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
323+
return cast<AffineMapAttr>(
324+
*(*this)->getInherentAttr(getTagMapAttrStrName()));
311325
}
312326

313327
/// Returns the tag memref index for this DMA operation.

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
706706
let extraClassDeclaration = [{
707707
// Returns the source MemRefType for this DMA operation.
708708
Value getSrcMemRef() { return getOperand(0); }
709+
OpOperand &getSrcMemRefMutable() { return getOperation()->getOpOperand(0); }
709710
// Returns the rank (number of indices) of the source MemRefType.
710711
unsigned getSrcMemRefRank() {
711712
return ::llvm::cast<MemRefType>(getSrcMemRef().getType()).getRank();
@@ -718,6 +719,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
718719

719720
// Returns the destination MemRefType for this DMA operations.
720721
Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
722+
OpOperand &getDstMemRefMutable() { return getOperation()->getOpOperand(1 + getSrcMemRefRank()); }
721723
// Returns the rank (number of indices) of the destination MemRefType.
722724
unsigned getDstMemRefRank() {
723725
return ::llvm::cast<MemRefType>(getDstMemRef().getType()).getRank();
@@ -745,6 +747,9 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
745747
Value getTagMemRef() {
746748
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
747749
}
750+
OpOperand &getTagMemRefMutable() {
751+
return getOperation()->getOpOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
752+
}
748753

749754
// Returns the rank (number of indices) of the tag MemRefType.
750755
unsigned getTagMemRefRank() {
@@ -801,11 +806,11 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
801806
void getEffects(
802807
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &
803808
effects) {
804-
effects.emplace_back(MemoryEffects::Read::get(), getSrcMemRef(),
809+
effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
805810
SideEffects::DefaultResource::get());
806-
effects.emplace_back(MemoryEffects::Write::get(), getDstMemRef(),
811+
effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
807812
SideEffects::DefaultResource::get());
808-
effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
813+
effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
809814
SideEffects::DefaultResource::get());
810815
}
811816
}];
@@ -852,7 +857,7 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
852857
void getEffects(
853858
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &
854859
effects) {
855-
effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
860+
effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
856861
SideEffects::DefaultResource::get());
857862
}
858863
}];

mlir/include/mlir/Interfaces/SideEffectInterfaces.h

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,20 @@ class EffectInstance {
149149
Resource *resource = DefaultResource::get())
150150
: effect(effect), resource(resource), value(value), stage(0),
151151
effectOnFullRegion(false) {}
152+
EffectInstance(EffectT *effect, OpOperand *opd,
153+
Resource *resource = DefaultResource::get())
154+
: effect(effect), resource(resource), value(opd), stage(0),
155+
effectOnFullRegion(false) {}
152156
EffectInstance(EffectT *effect, Value value, int stage,
153157
bool effectOnFullRegion,
154158
Resource *resource = DefaultResource::get())
155159
: effect(effect), resource(resource), value(value), stage(stage),
156160
effectOnFullRegion(effectOnFullRegion) {}
161+
EffectInstance(EffectT *effect, OpOperand *opd, int stage,
162+
bool effectOnFullRegion,
163+
Resource *resource = DefaultResource::get())
164+
: effect(effect), resource(resource), value(opd), stage(stage),
165+
effectOnFullRegion(effectOnFullRegion) {}
157166
EffectInstance(EffectT *effect, SymbolRefAttr symbol,
158167
Resource *resource = DefaultResource::get())
159168
: effect(effect), resource(resource), value(symbol), stage(0),
@@ -176,12 +185,21 @@ class EffectInstance {
176185
Resource *resource = DefaultResource::get())
177186
: effect(effect), resource(resource), value(value),
178187
parameters(parameters), stage(0), effectOnFullRegion(false) {}
188+
EffectInstance(EffectT *effect, OpOperand *opd, Attribute parameters,
189+
Resource *resource = DefaultResource::get())
190+
: effect(effect), resource(resource), value(opd), parameters(parameters),
191+
stage(0), effectOnFullRegion(false) {}
179192
EffectInstance(EffectT *effect, Value value, Attribute parameters, int stage,
180193
bool effectOnFullRegion,
181194
Resource *resource = DefaultResource::get())
182195
: effect(effect), resource(resource), value(value),
183196
parameters(parameters), stage(stage),
184197
effectOnFullRegion(effectOnFullRegion) {}
198+
EffectInstance(EffectT *effect, OpOperand *opd, Attribute parameters,
199+
int stage, bool effectOnFullRegion,
200+
Resource *resource = DefaultResource::get())
201+
: effect(effect), resource(resource), value(opd), parameters(parameters),
202+
stage(stage), effectOnFullRegion(effectOnFullRegion) {}
185203
EffectInstance(EffectT *effect, SymbolRefAttr symbol, Attribute parameters,
186204
Resource *resource = DefaultResource::get())
187205
: effect(effect), resource(resource), value(symbol),
@@ -199,7 +217,17 @@ class EffectInstance {
199217
/// Return the value the effect is applied on, or nullptr if there isn't a
200218
/// known value being affected.
201219
Value getValue() const {
202-
return value ? llvm::dyn_cast_if_present<Value>(value) : Value();
220+
if (!value || llvm::isa_and_present<SymbolRefAttr>(value)) {
221+
return Value();
222+
}
223+
if (Value v = llvm::dyn_cast_if_present<Value>(value)) {
224+
return v;
225+
}
226+
return cast_if_present<OpOperand *>(value)->get();
227+
}
228+
229+
OpOperand *getOpOperand() const {
230+
return value ? dyn_cast_if_present<OpOperand *>(value) : nullptr;
203231
}
204232

205233
/// Return the symbol reference the effect is applied on, or nullptr if there
@@ -229,7 +257,7 @@ class EffectInstance {
229257
Resource *resource;
230258

231259
/// The Symbol or Value that the effect applies to. This is optionally null.
232-
PointerUnion<SymbolRefAttr, Value> value;
260+
PointerUnion<SymbolRefAttr, Value, OpOperand *> value;
233261

234262
/// Additional parameters of the effect instance. An attribute is used for
235263
/// type-safe structured storage and context-based uniquing. Concrete effects

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1703,11 +1703,11 @@ LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
17031703
void AffineDmaStartOp::getEffects(
17041704
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
17051705
&effects) {
1706-
effects.emplace_back(MemoryEffects::Read::get(), getSrcMemRef(),
1706+
effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
17071707
SideEffects::DefaultResource::get());
1708-
effects.emplace_back(MemoryEffects::Write::get(), getDstMemRef(),
1708+
effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
17091709
SideEffects::DefaultResource::get());
1710-
effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
1710+
effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
17111711
SideEffects::DefaultResource::get());
17121712
}
17131713

@@ -1793,7 +1793,7 @@ LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
17931793
void AffineDmaWaitOp::getEffects(
17941794
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
17951795
&effects) {
1796-
effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
1796+
effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
17971797
SideEffects::DefaultResource::get());
17981798
}
17991799

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ void MaterializeInDestinationOp::getEffects(
728728
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
729729
&effects) {
730730
if (isa<BaseMemRefType>(getDest().getType()))
731-
effects.emplace_back(MemoryEffects::Write::get(), getDest(),
731+
effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
732732
SideEffects::DefaultResource::get());
733733
}
734734

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ Type GEPOp::getResultPtrElementType() {
825825
void LoadOp::getEffects(
826826
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
827827
&effects) {
828-
effects.emplace_back(MemoryEffects::Read::get(), getAddr());
828+
effects.emplace_back(MemoryEffects::Read::get(), &getAddrMutable());
829829
// Volatile operations can have target-specific read-write effects on
830830
// memory besides the one referred to by the pointer operand.
831831
// Similarly, atomic operations that are monotonic or stricter cause
@@ -902,7 +902,7 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
902902
void StoreOp::getEffects(
903903
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
904904
&effects) {
905-
effects.emplace_back(MemoryEffects::Write::get(), getAddr());
905+
effects.emplace_back(MemoryEffects::Write::get(), &getAddrMutable());
906906
// Volatile operations can have target-specific read-write effects on
907907
// memory besides the one referred to by the pointer operand.
908908
// Similarly, atomic operations that are monotonic or stricter cause

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,7 +1128,8 @@ static void getGenericEffectsImpl(
11281128
if (!llvm::isa<MemRefType>(operand.getType()))
11291129
continue;
11301130
if (linalgOp.payloadUsesValueFromOperand(&linalgOp->getOpOperand(index))) {
1131-
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
1131+
effects.emplace_back(MemoryEffects::Read::get(),
1132+
&linalgOp->getOpOperand(index), /*stage=*/0,
11321133
/*effectOnFullRegion=*/true,
11331134
SideEffects::DefaultResource::get());
11341135
}
@@ -1138,13 +1139,16 @@ static void getGenericEffectsImpl(
11381139
for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInits())) {
11391140
if (!llvm::isa<MemRefType>(operand.getType()))
11401141
continue;
1142+
unsigned operandIdx = index + inputOperandSize;
11411143
if (linalgOp.payloadUsesValueFromOperand(
1142-
&linalgOp->getOpOperand(index + inputOperandSize))) {
1143-
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
1144+
&linalgOp->getOpOperand(operandIdx))) {
1145+
effects.emplace_back(MemoryEffects::Read::get(),
1146+
&linalgOp->getOpOperand(operandIdx), /*stage=*/0,
11441147
/*effectOnFullRegion=*/true,
11451148
SideEffects::DefaultResource::get());
11461149
}
1147-
effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
1150+
effects.emplace_back(MemoryEffects::Write::get(),
1151+
&linalgOp->getOpOperand(operandIdx), /*stage=*/0,
11481152
/*effectOnFullRegion=*/true,
11491153
SideEffects::DefaultResource::get());
11501154
}
@@ -2546,20 +2550,27 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
25462550
void SoftmaxOp::getEffects(
25472551
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
25482552
&effects) {
2549-
for (Value operand : getDpsInputs()) {
2553+
SmallVector<Value> inputOperands = getDpsInputs();
2554+
for (auto [index, operand] : llvm::enumerate(inputOperands)) {
25502555
if (!llvm::isa<MemRefType>(operand.getType()))
25512556
continue;
2552-
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
2557+
effects.emplace_back(MemoryEffects::Read::get(),
2558+
&getOperation()->getOpOperand(index), /*stage=*/0,
25532559
/*effectOnFullRegion=*/true,
25542560
SideEffects::DefaultResource::get());
25552561
}
2556-
for (Value operand : getDpsInits()) {
2562+
2563+
unsigned inputOperandSize = inputOperands.size();
2564+
for (auto [index, operand] : llvm::enumerate(getDpsInits())) {
25572565
if (!llvm::isa<MemRefType>(operand.getType()))
25582566
continue;
2559-
effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
2567+
unsigned operandIdx = index + inputOperandSize;
2568+
effects.emplace_back(MemoryEffects::Read::get(),
2569+
&getOperation()->getOpOperand(operandIdx), /*stage=*/0,
25602570
/*effectOnFullRegion=*/true,
25612571
SideEffects::DefaultResource::get());
2562-
effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
2572+
effects.emplace_back(MemoryEffects::Write::get(),
2573+
&getOperation()->getOpOperand(operandIdx), /*stage=*/0,
25632574
/*effectOnFullRegion=*/true,
25642575
SideEffects::DefaultResource::get());
25652576
}

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4123,7 +4123,7 @@ void TransferReadOp::getEffects(
41234123
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
41244124
&effects) {
41254125
if (llvm::isa<MemRefType>(getShapedType()))
4126-
effects.emplace_back(MemoryEffects::Read::get(), getSource(),
4126+
effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
41274127
SideEffects::DefaultResource::get());
41284128
}
41294129

@@ -4497,7 +4497,7 @@ void TransferWriteOp::getEffects(
44974497
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
44984498
&effects) {
44994499
if (llvm::isa<MemRefType>(getShapedType()))
4500-
effects.emplace_back(MemoryEffects::Write::get(), getSource(),
4500+
effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
45014501
SideEffects::DefaultResource::get());
45024502
}
45034503

0 commit comments

Comments
 (0)