Skip to content

Commit 3dccee5

Browse files
committed
refactor(*): Include more precise side effects
This patch adds more precise side effects to the current ops with memory effects, allowing us to determine which OpOperand/OpResult/BlockArgument the operation reads or writes, rather than just recording the reading and writing of values. This allows for convenient use of precise side effects to achieve analysis and optimization. Related discussions: https://discourse.llvm.org/t/rfc-add-operandindex-to-sideeffect-instance/79243
1 parent 12fcca0 commit 3dccee5

File tree

33 files changed

+622
-299
lines changed

33 files changed

+622
-299
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ class AffineDmaStartOp
107107

108108
/// Returns the source MemRefType for this DMA operation.
109109
Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
110+
OpOperand &getSrcMemRefMutable() {
111+
return getOperation()->getOpOperand(getSrcMemRefOperandIndex());
112+
}
110113
MemRefType getSrcMemRefType() {
111114
return cast<MemRefType>(getSrcMemRef().getType());
112115
}
@@ -117,7 +120,8 @@ class AffineDmaStartOp
117120
/// Returns the affine map used to access the source memref.
118121
AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
119122
AffineMapAttr getSrcMapAttr() {
120-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getSrcMapAttrStrName()));
123+
return cast<AffineMapAttr>(
124+
*(*this)->getInherentAttr(getSrcMapAttrStrName()));
121125
}
122126

123127
/// Returns the source memref affine map indices for this DMA operation.
@@ -139,6 +143,9 @@ class AffineDmaStartOp
139143

140144
/// Returns the destination MemRefType for this DMA operation.
141145
Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
146+
OpOperand &getDstMemRefMutable() {
147+
return getOperation()->getOpOperand(getDstMemRefOperandIndex());
148+
}
142149
MemRefType getDstMemRefType() {
143150
return cast<MemRefType>(getDstMemRef().getType());
144151
}
@@ -156,7 +163,8 @@ class AffineDmaStartOp
156163
/// Returns the affine map used to access the destination memref.
157164
AffineMap getDstMap() { return getDstMapAttr().getValue(); }
158165
AffineMapAttr getDstMapAttr() {
159-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getDstMapAttrStrName()));
166+
return cast<AffineMapAttr>(
167+
*(*this)->getInherentAttr(getDstMapAttrStrName()));
160168
}
161169

162170
/// Returns the destination memref indices for this DMA operation.
@@ -173,6 +181,9 @@ class AffineDmaStartOp
173181

174182
/// Returns the Tag MemRef for this DMA operation.
175183
Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
184+
OpOperand &getTagMemRefMutable() {
185+
return getOperation()->getOpOperand(getTagMemRefOperandIndex());
186+
}
176187
MemRefType getTagMemRefType() {
177188
return cast<MemRefType>(getTagMemRef().getType());
178189
}
@@ -185,7 +196,8 @@ class AffineDmaStartOp
185196
/// Returns the affine map used to access the tag memref.
186197
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
187198
AffineMapAttr getTagMapAttr() {
188-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
199+
return cast<AffineMapAttr>(
200+
*(*this)->getInherentAttr(getTagMapAttrStrName()));
189201
}
190202

191203
/// Returns the tag memref indices for this DMA operation.
@@ -300,14 +312,16 @@ class AffineDmaWaitOp
300312

301313
/// Returns the Tag MemRef associated with the DMA operation being waited on.
302314
Value getTagMemRef() { return getOperand(0); }
315+
OpOperand &getTagMemRefMutable() { return getOperation()->getOpOperand(0); }
303316
MemRefType getTagMemRefType() {
304317
return cast<MemRefType>(getTagMemRef().getType());
305318
}
306319

307320
/// Returns the affine map used to access the tag memref.
308321
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
309322
AffineMapAttr getTagMapAttr() {
310-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
323+
return cast<AffineMapAttr>(
324+
*(*this)->getInherentAttr(getTagMapAttrStrName()));
311325
}
312326

313327
/// Returns the tag memref index for this DMA operation.

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
706706
let extraClassDeclaration = [{
707707
// Returns the source MemRefType for this DMA operation.
708708
Value getSrcMemRef() { return getOperand(0); }
709+
OpOperand &getSrcMemRefMutable() { return getOperation()->getOpOperand(0); }
709710
// Returns the rank (number of indices) of the source MemRefType.
710711
unsigned getSrcMemRefRank() {
711712
return ::llvm::cast<MemRefType>(getSrcMemRef().getType()).getRank();
@@ -718,6 +719,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
718719

719720
// Returns the destination MemRefType for this DMA operations.
720721
Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
722+
OpOperand &getDstMemRefMutable() { return getOperation()->getOpOperand(1 + getSrcMemRefRank()); }
721723
// Returns the rank (number of indices) of the destination MemRefType.
722724
unsigned getDstMemRefRank() {
723725
return ::llvm::cast<MemRefType>(getDstMemRef().getType()).getRank();
@@ -745,6 +747,9 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
745747
Value getTagMemRef() {
746748
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
747749
}
750+
OpOperand &getTagMemRefMutable() {
751+
return getOperation()->getOpOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
752+
}
748753

749754
// Returns the rank (number of indices) of the tag MemRefType.
750755
unsigned getTagMemRefRank() {
@@ -801,11 +806,11 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
801806
void getEffects(
802807
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &
803808
effects) {
804-
effects.emplace_back(MemoryEffects::Read::get(), getSrcMemRef(),
809+
effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
805810
SideEffects::DefaultResource::get());
806-
effects.emplace_back(MemoryEffects::Write::get(), getDstMemRef(),
811+
effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
807812
SideEffects::DefaultResource::get());
808-
effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
813+
effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
809814
SideEffects::DefaultResource::get());
810815
}
811816
}];
@@ -852,7 +857,7 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
852857
void getEffects(
853858
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &
854859
effects) {
855-
effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
860+
effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
856861
SideEffects::DefaultResource::get());
857862
}
858863
}];

mlir/include/mlir/Dialect/Transform/Interfaces/MatchInterfaces.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ class AtMostOneOpMatcherOpTrait
102102
}
103103

104104
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
105-
onlyReadsHandle(this->getOperation()->getOperands(), effects);
106-
producesHandle(this->getOperation()->getResults(), effects);
105+
onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
106+
producesHandle(this->getOperation()->getOpResults(), effects);
107107
onlyReadsPayload(effects);
108108
}
109109
};
@@ -163,8 +163,8 @@ class SingleValueMatcherOpTrait
163163
}
164164

165165
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
166-
onlyReadsHandle(this->getOperation()->getOperands(), effects);
167-
producesHandle(this->getOperation()->getResults(), effects);
166+
onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
167+
producesHandle(this->getOperation()->getOpResults(), effects);
168168
onlyReadsPayload(effects);
169169
}
170170
};

mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,11 +1261,13 @@ struct PayloadIRResource
12611261
/// - consumes = Read + Free,
12621262
/// - produces = Allocate + Write,
12631263
/// - onlyReads = Read.
1264-
void consumesHandle(ValueRange handles,
1264+
void consumesHandle(MutableArrayRef<OpOperand> handles,
12651265
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
1266-
void producesHandle(ValueRange handles,
1266+
void producesHandle(ResultRange handles,
12671267
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
1268-
void onlyReadsHandle(ValueRange handles,
1268+
void producesHandle(MutableArrayRef<BlockArgument> handles,
1269+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
1270+
void onlyReadsHandle(MutableArrayRef<OpOperand> handles,
12691271
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
12701272

12711273
/// Checks whether the transform op consumes the given handle.
@@ -1296,8 +1298,8 @@ class FunctionalStyleTransformOpTrait
12961298
/// the results by allocating and writing it and reads/writes the payload IR
12971299
/// in the process.
12981300
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1299-
consumesHandle(this->getOperation()->getOperands(), effects);
1300-
producesHandle(this->getOperation()->getResults(), effects);
1301+
consumesHandle(this->getOperation()->getOpOperands(), effects);
1302+
producesHandle(this->getOperation()->getOpResults(), effects);
13011303
modifiesPayload(effects);
13021304
}
13031305

@@ -1322,8 +1324,8 @@ class NavigationTransformOpTrait
13221324
/// This op produces handles to the Payload IR without consuming the original
13231325
/// handles and without modifying the IR itself.
13241326
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1325-
onlyReadsHandle(this->getOperation()->getOperands(), effects);
1326-
producesHandle(this->getOperation()->getResults(), effects);
1327+
onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
1328+
producesHandle(this->getOperation()->getOpResults(), effects);
13271329
if (llvm::any_of(this->getOperation()->getOperandTypes(), [](Type t) {
13281330
return isa<TransformHandleTypeInterface,
13291331
TransformValueHandleTypeInterface>(t);

mlir/include/mlir/Interfaces/SideEffectInterfaces.h

Lines changed: 63 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,19 @@ class EffectInstance {
145145
Resource *resource = DefaultResource::get())
146146
: effect(effect), resource(resource), stage(stage),
147147
effectOnFullRegion(effectOnFullRegion) {}
148-
EffectInstance(EffectT *effect, Value value,
148+
template <typename T,
149+
std::enable_if_t<
150+
llvm::is_one_of<T, OpOperand *, OpResult, BlockArgument>::value,
151+
bool> = true>
152+
EffectInstance(EffectT *effect, T value,
149153
Resource *resource = DefaultResource::get())
150154
: effect(effect), resource(resource), value(value), stage(0),
151155
effectOnFullRegion(false) {}
152-
EffectInstance(EffectT *effect, Value value, int stage,
153-
bool effectOnFullRegion,
156+
template <typename T,
157+
std::enable_if_t<
158+
llvm::is_one_of<T, OpOperand *, OpResult, BlockArgument>::value,
159+
bool> = true>
160+
EffectInstance(EffectT *effect, T value, int stage, bool effectOnFullRegion,
154161
Resource *resource = DefaultResource::get())
155162
: effect(effect), resource(resource), value(value), stage(stage),
156163
effectOnFullRegion(effectOnFullRegion) {}
@@ -172,11 +179,19 @@ class EffectInstance {
172179
Resource *resource = DefaultResource::get())
173180
: effect(effect), resource(resource), parameters(parameters),
174181
stage(stage), effectOnFullRegion(effectOnFullRegion) {}
175-
EffectInstance(EffectT *effect, Value value, Attribute parameters,
182+
template <typename T,
183+
std::enable_if_t<
184+
llvm::is_one_of<T, OpOperand *, OpResult, BlockArgument>::value,
185+
bool> = true>
186+
EffectInstance(EffectT *effect, T value, Attribute parameters,
176187
Resource *resource = DefaultResource::get())
177188
: effect(effect), resource(resource), value(value),
178189
parameters(parameters), stage(0), effectOnFullRegion(false) {}
179-
EffectInstance(EffectT *effect, Value value, Attribute parameters, int stage,
190+
template <typename T,
191+
std::enable_if_t<
192+
llvm::is_one_of<T, OpOperand *, OpResult, BlockArgument>::value,
193+
bool> = true>
194+
EffectInstance(EffectT *effect, T value, Attribute parameters, int stage,
180195
bool effectOnFullRegion,
181196
Resource *resource = DefaultResource::get())
182197
: effect(effect), resource(resource), value(value),
@@ -199,7 +214,26 @@ class EffectInstance {
199214
/// Return the value the effect is applied on, or nullptr if there isn't a
200215
/// known value being affected.
201216
Value getValue() const {
202-
return value ? llvm::dyn_cast_if_present<Value>(value) : Value();
217+
if (!value || llvm::isa_and_present<SymbolRefAttr>(value)) {
218+
return Value();
219+
}
220+
if (OpOperand *operand = llvm::dyn_cast_if_present<OpOperand *>(value)) {
221+
return operand->get();
222+
}
223+
if (OpResult result = llvm::dyn_cast_if_present<OpResult>(value)) {
224+
return result;
225+
}
226+
return cast_if_present<BlockArgument>(value);
227+
}
228+
229+
/// Returns the OpOperand effect is applied on, or nullptr if there isn't a
230+
/// known value being effected.
231+
template <typename T,
232+
std::enable_if_t<
233+
llvm::is_one_of<T, OpOperand *, OpResult, BlockArgument>::value,
234+
bool> = true>
235+
T getEffectValue() const {
236+
return value ? dyn_cast_if_present<T>(value) : nullptr;
203237
}
204238

205239
/// Return the symbol reference the effect is applied on, or nullptr if there
@@ -229,7 +263,7 @@ class EffectInstance {
229263
Resource *resource;
230264

231265
/// The Symbol or Value that the effect applies to. This is optionally null.
232-
PointerUnion<SymbolRefAttr, Value> value;
266+
PointerUnion<SymbolRefAttr, OpOperand *, OpResult, BlockArgument> value;
233267

234268
/// Additional parameters of the effect instance. An attribute is used for
235269
/// type-safe structured storage and context-based uniquing. Concrete effects
@@ -348,17 +382,32 @@ struct Write : public Effect::Base<Write> {};
348382
// SideEffect Utilities
349383
//===----------------------------------------------------------------------===//
350384

385+
/// Returns true if `op` has only an effect of type `EffectTy`.
386+
template <typename EffectTy>
387+
bool hasSingleEffect(Operation *op);
388+
351389
/// Returns true if `op` has only an effect of type `EffectTy` (and of no other
352-
/// type) on `value`. If no value is provided, simply check if effects of that
353-
/// type and only of that type are present.
390+
/// type) on `value`.
354391
template <typename EffectTy>
355-
bool hasSingleEffect(Operation *op, Value value = nullptr);
392+
bool hasSingleEffect(Operation *op, Value value);
393+
394+
/// Returns true if `op` has only an effect of type `EffectTy` (and of no other
395+
/// type) on `value` of type `ValueTy`.
396+
template <typename ValueTy, typename EffectTy>
397+
bool hasSingleEffect(Operation *op, ValueTy value);
356398

357-
/// Returns true if `op` has an effect of type `EffectTy` on `value`. If no
358-
/// `value` is provided, simply check if effects of the given type(s) are
359-
/// present.
399+
/// Returns true if `op` has an effect of type `EffectTy`.
360400
template <typename... EffectTys>
361-
bool hasEffect(Operation *op, Value value = nullptr);
401+
bool hasEffect(Operation *op);
402+
403+
/// Returns true if `op` has an effect of type `EffectTy` on `value`.
404+
template <typename... EffectTys>
405+
bool hasEffect(Operation *op, Value value);
406+
407+
/// Returns true if `op` has an effect of type `EffectTy` on `value` of type
408+
/// `ValueTy`.
409+
template <typename ValueTy, typename... EffectTys>
410+
bool hasEffect(Operation *op, ValueTy value);
362411

363412
/// Return true if the given operation is unused, and has no side effects on
364413
/// memory that prevent erasing.

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1703,11 +1703,11 @@ LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
17031703
void AffineDmaStartOp::getEffects(
17041704
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
17051705
&effects) {
1706-
effects.emplace_back(MemoryEffects::Read::get(), getSrcMemRef(),
1706+
effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
17071707
SideEffects::DefaultResource::get());
1708-
effects.emplace_back(MemoryEffects::Write::get(), getDstMemRef(),
1708+
effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
17091709
SideEffects::DefaultResource::get());
1710-
effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
1710+
effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
17111711
SideEffects::DefaultResource::get());
17121712
}
17131713

@@ -1793,7 +1793,7 @@ LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
17931793
void AffineDmaWaitOp::getEffects(
17941794
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
17951795
&effects) {
1796-
effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
1796+
effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
17971797
SideEffects::DefaultResource::get());
17981798
}
17991799

mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
142142

143143
void SimplifyBoundedAffineOpsOp::getEffects(
144144
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
145-
consumesHandle(getTarget(), effects);
146-
for (Value v : getBoundedValues())
147-
onlyReadsHandle(v, effects);
145+
consumesHandle(getTargetMutable(), effects);
146+
for (OpOperand &operand : getBoundedValuesMutable())
147+
onlyReadsHandle(operand, effects);
148148
modifiesPayload(effects);
149149
}
150150

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ void MaterializeInDestinationOp::getEffects(
728728
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
729729
&effects) {
730730
if (isa<BaseMemRefType>(getDest().getType()))
731-
effects.emplace_back(MemoryEffects::Write::get(), getDest(),
731+
effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
732732
SideEffects::DefaultResource::get());
733733
}
734734

mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ DiagnosedSilenceableFailure transform::BufferLoopHoistingOp::applyToOne(
3636

3737
void transform::BufferLoopHoistingOp::getEffects(
3838
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
39-
onlyReadsHandle(getTarget(), effects);
39+
onlyReadsHandle(getTargetMutable(), effects);
4040
modifiesPayload(effects);
4141
}
4242

@@ -110,7 +110,7 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
110110

111111
void transform::EliminateEmptyTensorsOp::getEffects(
112112
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
113-
onlyReadsHandle(getTarget(), effects);
113+
onlyReadsHandle(getTargetMutable(), effects);
114114
modifiesPayload(effects);
115115
}
116116

mlir/lib/Dialect/Func/TransformOps/FuncTransformOps.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -216,14 +216,14 @@ LogicalResult transform::CastAndCallOp::verify() {
216216

217217
void transform::CastAndCallOp::getEffects(
218218
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
219-
transform::onlyReadsHandle(getInsertionPoint(), effects);
219+
transform::onlyReadsHandle(getInsertionPointMutable(), effects);
220220
if (getInputs())
221-
transform::onlyReadsHandle(getInputs(), effects);
221+
transform::onlyReadsHandle(getInputsMutable(), effects);
222222
if (getOutputs())
223-
transform::onlyReadsHandle(getOutputs(), effects);
223+
transform::onlyReadsHandle(getOutputsMutable(), effects);
224224
if (getFunction())
225-
transform::onlyReadsHandle(getFunction(), effects);
226-
transform::producesHandle(getResult(), effects);
225+
transform::onlyReadsHandle(getFunctionMutable(), effects);
226+
transform::producesHandle(getOperation()->getOpResults(), effects);
227227
transform::modifiesPayload(effects);
228228
}
229229

0 commit comments

Comments
 (0)