-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][side effect] refactor(*): Include more precise side effects #94213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-fir-hlfir @llvm/pr-subscribers-mlir-affine Author: donald chen (cxy-1993) ChangesThis patch adds more precise side effects to the current ops with memory effects, allowing us to determine which OpOperands the operation reads or writes, rather than just recording the reading and writing of values. Full diff: https://github.com/llvm/llvm-project/pull/94213.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index f070d04886190..5c75e102c3d40 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -107,6 +107,9 @@ class AffineDmaStartOp
/// Returns the source MemRefType for this DMA operation.
Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
+ OpOperand &getSrcMemRefMutable() {
+ return getOperation()->getOpOperand(getSrcMemRefOperandIndex());
+ }
MemRefType getSrcMemRefType() {
return cast<MemRefType>(getSrcMemRef().getType());
}
@@ -117,7 +120,8 @@ class AffineDmaStartOp
/// Returns the affine map used to access the source memref.
AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
AffineMapAttr getSrcMapAttr() {
- return cast<AffineMapAttr>(*(*this)->getInherentAttr(getSrcMapAttrStrName()));
+ return cast<AffineMapAttr>(
+ *(*this)->getInherentAttr(getSrcMapAttrStrName()));
}
/// Returns the source memref affine map indices for this DMA operation.
@@ -139,6 +143,9 @@ class AffineDmaStartOp
/// Returns the destination MemRefType for this DMA operation.
Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
+ OpOperand &getDstMemRefMutable() {
+ return getOperation()->getOpOperand(getDstMemRefOperandIndex());
+ }
MemRefType getDstMemRefType() {
return cast<MemRefType>(getDstMemRef().getType());
}
@@ -156,7 +163,8 @@ class AffineDmaStartOp
/// Returns the affine map used to access the destination memref.
AffineMap getDstMap() { return getDstMapAttr().getValue(); }
AffineMapAttr getDstMapAttr() {
- return cast<AffineMapAttr>(*(*this)->getInherentAttr(getDstMapAttrStrName()));
+ return cast<AffineMapAttr>(
+ *(*this)->getInherentAttr(getDstMapAttrStrName()));
}
/// Returns the destination memref indices for this DMA operation.
@@ -173,6 +181,9 @@ class AffineDmaStartOp
/// Returns the Tag MemRef for this DMA operation.
Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
+ OpOperand &getTagMemRefMutable() {
+ return getOperation()->getOpOperand(getTagMemRefOperandIndex());
+ }
MemRefType getTagMemRefType() {
return cast<MemRefType>(getTagMemRef().getType());
}
@@ -185,7 +196,8 @@ class AffineDmaStartOp
/// Returns the affine map used to access the tag memref.
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
AffineMapAttr getTagMapAttr() {
- return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
+ return cast<AffineMapAttr>(
+ *(*this)->getInherentAttr(getTagMapAttrStrName()));
}
/// Returns the tag memref indices for this DMA operation.
@@ -300,6 +312,7 @@ class AffineDmaWaitOp
/// Returns the Tag MemRef associated with the DMA operation being waited on.
Value getTagMemRef() { return getOperand(0); }
+ OpOperand &getTagMemRefMutable() { return getOperation()->getOpOperand(0); }
MemRefType getTagMemRefType() {
return cast<MemRefType>(getTagMemRef().getType());
}
@@ -307,7 +320,8 @@ class AffineDmaWaitOp
/// Returns the affine map used to access the tag memref.
AffineMap getTagMap() { return getTagMapAttr().getValue(); }
AffineMapAttr getTagMapAttr() {
- return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
+ return cast<AffineMapAttr>(
+ *(*this)->getInherentAttr(getTagMapAttrStrName()));
}
/// Returns the tag memref index for this DMA operation.
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index 63e6ed059deb1..0606bfd28503a 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -706,6 +706,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
let extraClassDeclaration = [{
// Returns the source MemRefType for this DMA operation.
Value getSrcMemRef() { return getOperand(0); }
+ OpOperand &getSrcMemRefMutable() { return getOperation()->getOpOperand(0); }
// Returns the rank (number of indices) of the source MemRefType.
unsigned getSrcMemRefRank() {
return ::llvm::cast<MemRefType>(getSrcMemRef().getType()).getRank();
@@ -718,6 +719,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
// Returns the destination MemRefType for this DMA operations.
Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); }
+ OpOperand &getDstMemRefMutable() { return getOperation()->getOpOperand(1 + getSrcMemRefRank()); }
// Returns the rank (number of indices) of the destination MemRefType.
unsigned getDstMemRefRank() {
return ::llvm::cast<MemRefType>(getDstMemRef().getType()).getRank();
@@ -745,6 +747,9 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
Value getTagMemRef() {
return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
}
+ OpOperand &getTagMemRefMutable() {
+ return getOperation()->getOpOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1);
+ }
// Returns the rank (number of indices) of the tag MemRefType.
unsigned getTagMemRefRank() {
@@ -801,11 +806,11 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
void getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &
effects) {
- effects.emplace_back(MemoryEffects::Read::get(), getSrcMemRef(),
+ effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
SideEffects::DefaultResource::get());
- effects.emplace_back(MemoryEffects::Write::get(), getDstMemRef(),
+ effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
SideEffects::DefaultResource::get());
- effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
+ effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
SideEffects::DefaultResource::get());
}
}];
@@ -852,7 +857,7 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
void getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> &
effects) {
- effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
+ effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
SideEffects::DefaultResource::get());
}
}];
diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
index ec4e36263bbe6..61af0acfb986e 100644
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -149,11 +149,20 @@ class EffectInstance {
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), value(value), stage(0),
effectOnFullRegion(false) {}
+ EffectInstance(EffectT *effect, OpOperand *opd,
+ Resource *resource = DefaultResource::get())
+ : effect(effect), resource(resource), value(opd), stage(0),
+ effectOnFullRegion(false) {}
EffectInstance(EffectT *effect, Value value, int stage,
bool effectOnFullRegion,
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), value(value), stage(stage),
effectOnFullRegion(effectOnFullRegion) {}
+ EffectInstance(EffectT *effect, OpOperand *opd, int stage,
+ bool effectOnFullRegion,
+ Resource *resource = DefaultResource::get())
+ : effect(effect), resource(resource), value(opd), stage(stage),
+ effectOnFullRegion(effectOnFullRegion) {}
EffectInstance(EffectT *effect, SymbolRefAttr symbol,
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), value(symbol), stage(0),
@@ -176,12 +185,21 @@ class EffectInstance {
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), value(value),
parameters(parameters), stage(0), effectOnFullRegion(false) {}
+ EffectInstance(EffectT *effect, OpOperand *opd, Attribute parameters,
+ Resource *resource = DefaultResource::get())
+ : effect(effect), resource(resource), value(opd), parameters(parameters),
+ stage(0), effectOnFullRegion(false) {}
EffectInstance(EffectT *effect, Value value, Attribute parameters, int stage,
bool effectOnFullRegion,
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), value(value),
parameters(parameters), stage(stage),
effectOnFullRegion(effectOnFullRegion) {}
+ EffectInstance(EffectT *effect, OpOperand *opd, Attribute parameters,
+ int stage, bool effectOnFullRegion,
+ Resource *resource = DefaultResource::get())
+ : effect(effect), resource(resource), value(opd), parameters(parameters),
+ stage(stage), effectOnFullRegion(effectOnFullRegion) {}
EffectInstance(EffectT *effect, SymbolRefAttr symbol, Attribute parameters,
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), value(symbol),
@@ -199,7 +217,17 @@ class EffectInstance {
/// Return the value the effect is applied on, or nullptr if there isn't a
/// known value being affected.
Value getValue() const {
- return value ? llvm::dyn_cast_if_present<Value>(value) : Value();
+ if (!value || llvm::isa_and_present<SymbolRefAttr>(value)) {
+ return Value();
+ }
+ if (Value v = llvm::dyn_cast_if_present<Value>(value)) {
+ return v;
+ }
+ return cast_if_present<OpOperand *>(value)->get();
+ }
+
+ OpOperand *getOpOperand() const {
+ return value ? dyn_cast_if_present<OpOperand *>(value) : nullptr;
}
/// Return the symbol reference the effect is applied on, or nullptr if there
@@ -229,7 +257,7 @@ class EffectInstance {
Resource *resource;
/// The Symbol or Value that the effect applies to. This is optionally null.
- PointerUnion<SymbolRefAttr, Value> value;
+ PointerUnion<SymbolRefAttr, Value, OpOperand *> value;
/// Additional parameters of the effect instance. An attribute is used for
/// type-safe structured storage and context-based uniquing. Concrete effects
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 2e31487bd55a0..3efe93c300f46 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1703,11 +1703,11 @@ LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
void AffineDmaStartOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- effects.emplace_back(MemoryEffects::Read::get(), getSrcMemRef(),
+ effects.emplace_back(MemoryEffects::Read::get(), &getSrcMemRefMutable(),
SideEffects::DefaultResource::get());
- effects.emplace_back(MemoryEffects::Write::get(), getDstMemRef(),
+ effects.emplace_back(MemoryEffects::Write::get(), &getDstMemRefMutable(),
SideEffects::DefaultResource::get());
- effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
+ effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
SideEffects::DefaultResource::get());
}
@@ -1793,7 +1793,7 @@ LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
void AffineDmaWaitOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- effects.emplace_back(MemoryEffects::Read::get(), getTagMemRef(),
+ effects.emplace_back(MemoryEffects::Read::get(), &getTagMemRefMutable(),
SideEffects::DefaultResource::get());
}
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 3b7b412842bfb..04a8ff30ee946 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -728,7 +728,7 @@ void MaterializeInDestinationOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
if (isa<BaseMemRefType>(getDest().getType()))
- effects.emplace_back(MemoryEffects::Write::get(), getDest(),
+ effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
SideEffects::DefaultResource::get());
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 60b911948d4a0..08259dd6597ca 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -825,7 +825,7 @@ Type GEPOp::getResultPtrElementType() {
void LoadOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- effects.emplace_back(MemoryEffects::Read::get(), getAddr());
+ effects.emplace_back(MemoryEffects::Read::get(), &getAddrMutable());
// Volatile operations can have target-specific read-write effects on
// memory besides the one referred to by the pointer operand.
// Similarly, atomic operations that are monotonic or stricter cause
@@ -902,7 +902,7 @@ void LoadOp::build(OpBuilder &builder, OperationState &state, Type type,
void StoreOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- effects.emplace_back(MemoryEffects::Write::get(), getAddr());
+ effects.emplace_back(MemoryEffects::Write::get(), &getAddrMutable());
// Volatile operations can have target-specific read-write effects on
// memory besides the one referred to by the pointer operand.
// Similarly, atomic operations that are monotonic or stricter cause
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..1026d121abd17 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1128,7 +1128,8 @@ static void getGenericEffectsImpl(
if (!llvm::isa<MemRefType>(operand.getType()))
continue;
if (linalgOp.payloadUsesValueFromOperand(&linalgOp->getOpOperand(index))) {
- effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+ effects.emplace_back(MemoryEffects::Read::get(),
+ &linalgOp->getOpOperand(index), /*stage=*/0,
/*effectOnFullRegion=*/true,
SideEffects::DefaultResource::get());
}
@@ -1138,13 +1139,16 @@ static void getGenericEffectsImpl(
for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInits())) {
if (!llvm::isa<MemRefType>(operand.getType()))
continue;
+ unsigned operandIdx = index + inputOperandSize;
if (linalgOp.payloadUsesValueFromOperand(
- &linalgOp->getOpOperand(index + inputOperandSize))) {
- effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+ &linalgOp->getOpOperand(operandIdx))) {
+ effects.emplace_back(MemoryEffects::Read::get(),
+ &linalgOp->getOpOperand(operandIdx), /*stage=*/0,
/*effectOnFullRegion=*/true,
SideEffects::DefaultResource::get());
}
- effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
+ effects.emplace_back(MemoryEffects::Write::get(),
+ &linalgOp->getOpOperand(operandIdx), /*stage=*/0,
/*effectOnFullRegion=*/true,
SideEffects::DefaultResource::get());
}
@@ -2546,20 +2550,27 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
void SoftmaxOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
- for (Value operand : getDpsInputs()) {
+ SmallVector<Value> inputOperands = getDpsInputs();
+ for (auto [index, operand] : llvm::enumerate(inputOperands)) {
if (!llvm::isa<MemRefType>(operand.getType()))
continue;
- effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+ effects.emplace_back(MemoryEffects::Read::get(),
+ &getOperation()->getOpOperand(index), /*stage=*/0,
/*effectOnFullRegion=*/true,
SideEffects::DefaultResource::get());
}
- for (Value operand : getDpsInits()) {
+
+ unsigned inputOperandSize = inputOperands.size();
+ for (auto [index, operand] : llvm::enumerate(getDpsInits())) {
if (!llvm::isa<MemRefType>(operand.getType()))
continue;
- effects.emplace_back(MemoryEffects::Read::get(), operand, /*stage=*/0,
+ unsigned operandIdx = index + inputOperandSize;
+ effects.emplace_back(MemoryEffects::Read::get(),
+ &getOperation()->getOpOperand(operandIdx), /*stage=*/0,
/*effectOnFullRegion=*/true,
SideEffects::DefaultResource::get());
- effects.emplace_back(MemoryEffects::Write::get(), operand, /*stage=*/0,
+ effects.emplace_back(MemoryEffects::Write::get(),
+ &getOperation()->getOpOperand(operandIdx), /*stage=*/0,
/*effectOnFullRegion=*/true,
SideEffects::DefaultResource::get());
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 58951641d33ce..f528c0a7960e7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4123,7 +4123,7 @@ void TransferReadOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
if (llvm::isa<MemRefType>(getShapedType()))
- effects.emplace_back(MemoryEffects::Read::get(), getSource(),
+ effects.emplace_back(MemoryEffects::Read::get(), &getSourceMutable(),
SideEffects::DefaultResource::get());
}
@@ -4497,7 +4497,7 @@ void TransferWriteOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
if (llvm::isa<MemRefType>(getShapedType()))
- effects.emplace_back(MemoryEffects::Write::get(), getSource(),
+ effects.emplace_back(MemoryEffects::Write::get(), &getSourceMutable(),
SideEffects::DefaultResource::get());
}
|
Related discussion: https://discourse.llvm.org/t/rfc-add-operandindex-to-sideeffect-instance/79243 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also mention in the commit message:
- Why this is useful.
- Link to the RFC.
@@ -149,11 +149,20 @@ class EffectInstance { | |||
Resource *resource = DefaultResource::get()) | |||
: effect(effect), resource(resource), value(value), stage(0), | |||
effectOnFullRegion(false) {} | |||
EffectInstance(EffectT *effect, OpOperand *opd, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The list of constructors gets longer and longer... Do you see a way to shorten it a bit?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As previously discussed(https://discourse.llvm.org/t/rfc-add-operandindex-to-sideeffect-instance/79243), I have added OpResult and BlockArgument, and reimplemented it using templates. Please review again
I think you also need additional variants of these functions to query the side effects: /// Returns true if `op` has an effect of type `EffectTy` on `value`. If no
/// `value` is provided, simply check if effects of the given type(s) are
/// present.
template <typename... EffectTys>
bool hasEffect(Operation *op, Value value = nullptr); |
5a7765d
to
a9f1a21
Compare
04b6a37
to
3dccee5
Compare
1be1abb
to
993137d
Compare
993137d
to
cb2df2f
Compare
Hi @matthias-springer @ftynse , please review this patch when you have time. The main changes in this patch are sideEffectInstance and some utility functions around it, and the other changes are due to function signature changes. |
This patch adds more precise side effects to the current ops with memory effects, allowing us to determine which OpOperand/OpResult/BlockArgument the operation reads or writes, rather than just recording the reading and writing of values. This allows for convenient use of precise side effects to achieve analysis and optimization. Related discussions: https://discourse.llvm.org/t/rfc-add-operandindex-to-sideeffect-instance/79243
cb2df2f
to
26b715b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to me, do you need help merging?
Thanks for your review, I"ll merge it after ci passed. |
…lvm#94213) This patch adds more precise side effects to the current ops with memory effects, allowing us to determine which OpOperand/OpResult/BlockArgument the operation reads or writes, rather than just recording the reading and writing of values. This allows for convenient use of precise side effects to achieve analysis and optimization. Related discussions: https://discourse.llvm.org/t/rfc-add-operandindex-to-sideeffect-instance/79243
This patch adds more precise side effects to the current ops with memory
effects, allowing us to determine which OpOperand/OpResult/BlockArgument the
operation reads or writes, rather than just recording the reading and writing
of values. This allows for convenient use of precise side effects to achieve
analysis and optimization.
Related discussions: https://discourse.llvm.org/t/rfc-add-operandindex-to-sideeffect-instance/79243