Skip to content

Commit 993137d

Browse files
committed
refactor(*): Include more precise side effects
This patch adds more precise side effects to the current ops with memory effects, allowing us to determine which OpOperand/OpResult/BlockArgument the operation reads or writes, rather than just recording the reading and writing of values. This allows for convenient use of precise side effects to achieve analysis and optimization. Related discussions: https://discourse.llvm.org/t/rfc-add-operandindex-to-sideeffect-instance/79243
1 parent efbd64c commit 993137d

File tree

36 files changed

+628
-305
lines changed

36 files changed

+628
-305
lines changed

mlir/examples/transform/Ch2/lib/MyExtension.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ void mlir::transform::ChangeCallTargetOp::getEffects(
129129
// Indicate that the `call` handle is only read by this operation because the
130130
// associated operation is not erased but rather modified in-place, so the
131131
// reference to it remains valid.
132-
onlyReadsHandle(getCall(), effects);
132+
onlyReadsHandle(getCallMutable(), effects);
133133

134134
// Indicate that the payload is modified by this operation.
135135
modifiesPayload(effects);

mlir/examples/transform/Ch3/lib/MyExtension.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ void mlir::transform::ChangeCallTargetOp::getEffects(
139139
// Indicate that the `call` handle is only read by this operation because the
140140
// associated operation is not erased but rather modified in-place, so the
141141
// reference to it remains valid.
142-
onlyReadsHandle(getCall(), effects);
142+
onlyReadsHandle(getCallMutable(), effects);
143143

144144
// Indicate that the payload is modified by this operation.
145145
modifiesPayload(effects);

mlir/examples/transform/Ch4/lib/MyExtension.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,8 @@ mlir::transform::HasOperandSatisfyingOp::apply(
160160
void mlir::transform::HasOperandSatisfyingOp::getEffects(
161161
llvm::SmallVectorImpl<mlir::MemoryEffects::EffectInstance> &effects) {
162162
onlyReadsPayload(effects);
163-
onlyReadsHandle(getOp(), effects);
164-
producesHandle(getPosition(), effects);
165-
producesHandle(getResults(), effects);
163+
onlyReadsHandle(getOpMutable(), effects);
164+
producesHandle(getOperation()->getOpResults(), effects);
166165
}
167166

168167
// Verify well-formedness of the operation and emit diagnostics if it is

mlir/include/mlir/Dialect/Affine/IR/AffineOps.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ class AffineDmaStartOp
107107

108108
/// Returns the source MemRefType for this DMA operation.
109109
Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
110+
OpOperand &getSrcMemRefMutable() {
111+
return getOperation()->getOpOperand(getSrcMemRefOperandIndex());
112+
}
110113
MemRefType getSrcMemRefType() {
111114
return cast<MemRefType>(getSrcMemRef().getType());
112115
}
@@ -117,7 +120,8 @@ class AffineDmaStartOp
117120
/// Returns the affine map used to access the source memref.
118121
AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
119122
AffineMapAttr getSrcMapAttr() {
120-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getSrcMapAttrStrName()));
123+
return cast<AffineMapAttr>(
124+
*(*this)->getInherentAttr(getSrcMapAttrStrName()));
121125
}
122126

123127
/// Returns the source memref affine map indices for this DMA operation.
@@ -139,6 +143,9 @@ class AffineDmaStartOp
139143

140144
/// Returns the destination MemRefType for this DMA operation.
141145
Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
146+
OpOperand &getDstMemRefMutable() {
147+
return getOperation()->getOpOperand(getDstMemRefOperandIndex());
148+
}
142149
MemRefType getDstMemRefType() {
143150
return cast<MemRefType>(getDstMemRef().getType());
144151
}
@@ -156,7 +163,8 @@ class AffineDmaStartOp
156163
/// Returns the affine map used to access the destination memref.
157164
AffineMap getDstMap() { return getDstMapAttr().getValue(); }
158165
AffineMapAttr getDstMapAttr() {
159-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getDstMapAttrStrName()));
166+
return cast<AffineMapAttr>(
167+
*(*this)->getInherentAttr(getDstMapAttrStrName()));
160168
}
161169

162170
/// Returns the destination memref indices for this DMA operation.
@@ -173,6 +181,9 @@ class AffineDmaStartOp
173181

174182
/// Returns the Tag MemRef for this DMA operation.
175183
Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
184+
OpOperand &getTagMemRefMutable() {
185+
return getOperation()->getOpOperand(getTagMemRefOperandIndex());
186+
}
176187
MemRefType getTagMemRefType() {
177188
return cast<MemRefType>(getTagMemRef().getType());
178189
}
@@ -185,7 +196,8 @@ class AffineDmaStartOp
185196
/// Returns the affine map used to access the tag memref.
186197
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
187198
AffineMapAttr getTagMapAttr() {
188-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
199+
return cast<AffineMapAttr>(
200+
*(*this)->getInherentAttr(getTagMapAttrStrName()));
189201
}
190202

191203
/// Returns the tag memref indices for this DMA operation.
@@ -300,14 +312,16 @@ class AffineDmaWaitOp
300312

301313
/// Returns the Tag MemRef associated with the DMA operation being waited on.
302314
Value getTagMemRef() { return getOperand(0); }
315+
OpOperand &getTagMemRefMutable() { return getOperation()->getOpOperand(0); }
303316
MemRefType getTagMemRefType() {
304317
return cast<MemRefType>(getTagMemRef().getType());
305318
}
306319

307320
/// Returns the affine map used to access the tag memref.
308321
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
309322
AffineMapAttr getTagMapAttr() {
310-
return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
323+
return cast<AffineMapAttr>(
324+
*(*this)->getInherentAttr(getTagMapAttrStrName()));
311325
}
312326

313327
/// Returns the tag memref index for this DMA operation.

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
706706
let extraClassDeclaration = [{
707707
// Returns the source MemRefType for this DMA operation.
708708
Value getSrcMemRef() { return getOperand(0); }
709+
OpOperand &getSrcMemRefMutable() { return getOperation()->getOpOperand(0); }
709710
// Returns the rank (number of indices) of the source MemRefType.
710711
unsigned getSrcMemRefRank() {
711712
return ::llvm::cast<MemRefType>(getSrcMemRef().getType()).getRank();
@@ -718,6 +719,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
718719

719720
// Returns the destination MemRefType for this DMA operations.
720721
Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
722+
OpOperand &getDstMemRefMutable() { return getOperation()->getOpOperand(1 + getSrcMemRefRank()); }
721723
// Returns the rank (number of indices) of the destination MemRefType.
722724
unsigned getDstMemRefRank() {
723725
return ::llvm::cast<MemRefType>(getDstMemRef().getType()).getRank();
@@ -745,6 +747,9 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
745747
Value getTagMemRef() {
746748
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
747749
}
750+
OpOperand &getTagMemRefMutable() {
751+
return getOperation()->getOpOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
752+
}
748753

749754
// Returns the rank (number of indices) of the tag MemRefType.
750755
unsigned getTagMemRefRank() {
@@ -801,11 +806,11 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
801806
void getEffects(
802807
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &
803808
effects) {
804-
effects.emplace_back(MemoryEffects::Read::get(), getSrcMemRef(),
809+
effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
805810
SideEffects::DefaultResource::get());
806-
effects.emplace_back(MemoryEffects::Write::get(), getDstMemRef(),
811+
effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
807812
SideEffects::DefaultResource::get());
808-
effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
813+
effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
809814
SideEffects::DefaultResource::get());
810815
}
811816
}];
@@ -852,7 +857,7 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
852857
void getEffects(
853858
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &
854859
effects) {
855-
effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
860+
effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
856861
SideEffects::DefaultResource::get());
857862
}
858863
}];

mlir/include/mlir/Dialect/Transform/Interfaces/MatchInterfaces.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ class AtMostOneOpMatcherOpTrait
102102
}
103103

104104
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
105-
onlyReadsHandle(this->getOperation()->getOperands(), effects);
106-
producesHandle(this->getOperation()->getResults(), effects);
105+
onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
106+
producesHandle(this->getOperation()->getOpResults(), effects);
107107
onlyReadsPayload(effects);
108108
}
109109
};
@@ -163,8 +163,8 @@ class SingleValueMatcherOpTrait
163163
}
164164

165165
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
166-
onlyReadsHandle(this->getOperation()->getOperands(), effects);
167-
producesHandle(this->getOperation()->getResults(), effects);
166+
onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
167+
producesHandle(this->getOperation()->getOpResults(), effects);
168168
onlyReadsPayload(effects);
169169
}
170170
};

mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1261,11 +1261,13 @@ struct PayloadIRResource
12611261
/// - consumes = Read + Free,
12621262
/// - produces = Allocate + Write,
12631263
/// - onlyReads = Read.
1264-
void consumesHandle(ValueRange handles,
1264+
void consumesHandle(MutableArrayRef<OpOperand> handles,
12651265
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
1266-
void producesHandle(ValueRange handles,
1266+
void producesHandle(ResultRange handles,
12671267
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
1268-
void onlyReadsHandle(ValueRange handles,
1268+
void producesHandle(MutableArrayRef<BlockArgument> handles,
1269+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
1270+
void onlyReadsHandle(MutableArrayRef<OpOperand> handles,
12691271
SmallVectorImpl<MemoryEffects::EffectInstance> &effects);
12701272

12711273
/// Checks whether the transform op consumes the given handle.
@@ -1296,8 +1298,8 @@ class FunctionalStyleTransformOpTrait
12961298
/// the results by allocating and writing it and reads/writes the payload IR
12971299
/// in the process.
12981300
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1299-
consumesHandle(this->getOperation()->getOperands(), effects);
1300-
producesHandle(this->getOperation()->getResults(), effects);
1301+
consumesHandle(this->getOperation()->getOpOperands(), effects);
1302+
producesHandle(this->getOperation()->getOpResults(), effects);
13011303
modifiesPayload(effects);
13021304
}
13031305

@@ -1322,8 +1324,8 @@ class NavigationTransformOpTrait
13221324
/// This op produces handles to the Payload IR without consuming the original
13231325
/// handles and without modifying the IR itself.
13241326
void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1325-
onlyReadsHandle(this->getOperation()->getOperands(), effects);
1326-
producesHandle(this->getOperation()->getResults(), effects);
1327+
onlyReadsHandle(this->getOperation()->getOpOperands(), effects);
1328+
producesHandle(this->getOperation()->getOpResults(), effects);
13271329
if (llvm::any_of(this->getOperation()->getOperandTypes(), [](Type t) {
13281330
return isa<TransformHandleTypeInterface,
13291331
TransformValueHandleTypeInterface>(t);

mlir/include/mlir/Interfaces/SideEffectInterfaces.h

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,19 @@ class EffectInstance {
145145
Resource *resource = DefaultResource::get())
146146
: effect(effect), resource(resource), stage(stage),
147147
effectOnFullRegion(effectOnFullRegion) {}
148-
EffectInstance(EffectT *effect, Value value,
148+
template <typename T,
149+
std::enable_if_t<
150+
llvm::is_one_of<T, OpOperand *, OpResult, BlockArgument>::value,
151+
bool> = true>
152+
EffectInstance(EffectT *effect, T value,
149153
Resource *resource = DefaultResource::get())
150154
: effect(effect), resource(resource), value(value), stage(0),
151155
effectOnFullRegion(false) {}
152-
EffectInstance(EffectT *effect, Value value, int stage,
153-
bool effectOnFullRegion,
156+
template <typename T,
157+
std::enable_if_t<
158+
llvm::is_one_of<T, OpOperand *, OpResult, BlockArgument>::value,
159+
bool> = true>
160+
EffectInstance(EffectT *effect, T value, int stage, bool effectOnFullRegion,
154161
Resource *resource = DefaultResource::get())
155162
: effect(effect), resource(resource), value(value), stage(stage),
156163
effectOnFullRegion(effectOnFullRegion) {}
@@ -172,11 +179,19 @@ class EffectInstance {
172179
Resource *resource = DefaultResource::get())
173180
: effect(effect), resource(resource), parameters(parameters),
174181
stage(stage), effectOnFullRegion(effectOnFullRegion) {}
175-
EffectInstance(EffectT *effect, Value value, Attribute parameters,
182+
template <typename T,
183+
std::enable_if_t<
184+
llvm::is_one_of<T, OpOperand *, OpResult, BlockArgument>::value,
185+
bool> = true>
186+
EffectInstance(EffectT *effect, T value, Attribute parameters,
176187
Resource *resource = DefaultResource::get())
177188
: effect(effect), resource(resource), value(value),
178189
parameters(parameters), stage(0), effectOnFullRegion(false) {}
179-
EffectInstance(EffectT *effect, Value value, Attribute parameters, int stage,
190+
template <typename T,
191+
std::enable_if_t<
192+
llvm::is_one_of<T, OpOperand *, OpResult, BlockArgument>::value,
193+
bool> = true>
194+
EffectInstance(EffectT *effect, T value, Attribute parameters, int stage,
180195
bool effectOnFullRegion,
181196
Resource *resource = DefaultResource::get())
182197
: effect(effect), resource(resource), value(value),
@@ -199,7 +214,26 @@ class EffectInstance {
199214
/// Return the value the effect is applied on, or nullptr if there isn't a
200215
/// known value being affected.
201216
Value getValue() const {
202-
return value ? llvm::dyn_cast_if_present<Value>(value) : Value();
217+
if (!value || llvm::isa_and_present<SymbolRefAttr>(value)) {
218+
return Value();
219+
}
220+
if (OpOperand *operand = llvm::dyn_cast_if_present<OpOperand *>(value)) {
221+
return operand->get();
222+
}
223+
if (OpResult result = llvm::dyn_cast_if_present<OpResult>(value)) {
224+
return result;
225+
}
226+
return cast_if_present<BlockArgument>(value);
227+
}
228+
229+
/// Returns the OpOperand effect is applied on, or nullptr if there isn't a
230+
/// known value being effected.
231+
template <typename T,
232+
std::enable_if_t<
233+
llvm::is_one_of<T, OpOperand *, OpResult, BlockArgument>::value,
234+
bool> = true>
235+
T getEffectValue() const {
236+
return value ? dyn_cast_if_present<T>(value) : nullptr;
203237
}
204238

205239
/// Return the symbol reference the effect is applied on, or nullptr if there
@@ -228,8 +262,9 @@ class EffectInstance {
228262
/// The resource that the given value resides in.
229263
Resource *resource;
230264

231-
/// The Symbol or Value that the effect applies to. This is optionally null.
232-
PointerUnion<SymbolRefAttr, Value> value;
265+
/// The Symbol, OpOperand, OpResult or BlockArgument that the effect applies
266+
/// to. This is optionally null.
267+
PointerUnion<SymbolRefAttr, OpOperand *, OpResult, BlockArgument> value;
233268

234269
/// Additional parameters of the effect instance. An attribute is used for
235270
/// type-safe structured storage and context-based uniquing. Concrete effects
@@ -348,17 +383,32 @@ struct Write : public Effect::Base<Write> {};
348383
// SideEffect Utilities
349384
//===----------------------------------------------------------------------===//
350385

386+
/// Returns true if `op` has only an effect of type `EffectTy`.
387+
template <typename EffectTy>
388+
bool hasSingleEffect(Operation *op);
389+
351390
/// Returns true if `op` has only an effect of type `EffectTy` (and of no other
352-
/// type) on `value`. If no value is provided, simply check if effects of that
353-
/// type and only of that type are present.
391+
/// type) on `value`.
354392
template <typename EffectTy>
355-
bool hasSingleEffect(Operation *op, Value value = nullptr);
393+
bool hasSingleEffect(Operation *op, Value value);
394+
395+
/// Returns true if `op` has only an effect of type `EffectTy` (and of no other
396+
/// type) on `value` of type `ValueTy`.
397+
template <typename ValueTy, typename EffectTy>
398+
bool hasSingleEffect(Operation *op, ValueTy value);
356399

357-
/// Returns true if `op` has an effect of type `EffectTy` on `value`. If no
358-
/// `value` is provided, simply check if effects of the given type(s) are
359-
/// present.
400+
/// Returns true if `op` has an effect of type `EffectTy`.
360401
template <typename... EffectTys>
361-
bool hasEffect(Operation *op, Value value = nullptr);
402+
bool hasEffect(Operation *op);
403+
404+
/// Returns true if `op` has an effect of type `EffectTy` on `value`.
405+
template <typename... EffectTys>
406+
bool hasEffect(Operation *op, Value value);
407+
408+
/// Returns true if `op` has an effect of type `EffectTy` on `value` of type
409+
/// `ValueTy`.
410+
template <typename ValueTy, typename... EffectTys>
411+
bool hasEffect(Operation *op, ValueTy value);
362412

363413
/// Return true if the given operation is unused, and has no side effects on
364414
/// memory that prevent erasing.

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1703,11 +1703,11 @@ LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
17031703
void AffineDmaStartOp::getEffects(
17041704
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
17051705
&effects) {
1706-
effects.emplace_back(MemoryEffects::Read::get(), getSrcMemRef(),
1706+
effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
17071707
SideEffects::DefaultResource::get());
1708-
effects.emplace_back(MemoryEffects::Write::get(), getDstMemRef(),
1708+
effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
17091709
SideEffects::DefaultResource::get());
1710-
effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
1710+
effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
17111711
SideEffects::DefaultResource::get());
17121712
}
17131713

@@ -1793,7 +1793,7 @@ LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
17931793
void AffineDmaWaitOp::getEffects(
17941794
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
17951795
&effects) {
1796-
effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
1796+
effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
17971797
SideEffects::DefaultResource::get());
17981798
}
17991799

mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,9 @@ SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter,
142142

143143
void SimplifyBoundedAffineOpsOp::getEffects(
144144
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
145-
consumesHandle(getTarget(), effects);
146-
for (Value v : getBoundedValues())
147-
onlyReadsHandle(v, effects);
145+
consumesHandle(getTargetMutable(), effects);
146+
for (OpOperand &operand : getBoundedValuesMutable())
147+
onlyReadsHandle(operand, effects);
148148
modifiesPayload(effects);
149149
}
150150

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ void MaterializeInDestinationOp::getEffects(
728728
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
729729
&effects) {
730730
if (isa<BaseMemRefType>(getDest().getType()))
731-
effects.emplace_back(MemoryEffects::Write::get(), getDest(),
731+
effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
732732
SideEffects::DefaultResource::get());
733733
}
734734

0 commit comments

Comments
 (0)