Skip to content

Commit d3aecb2

Browse files
committed
Move helpers to MemRefUtils
1 parent 246f8b3 commit d3aecb2

File tree

3 files changed

+58
-44
lines changed

3 files changed

+58
-44
lines changed

mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ namespace mlir {
2222

2323
class MemRefType;
2424

25+
/// A value with a memref type.
26+
using MemrefValue = TypedValue<BaseMemRefType>;
27+
2528
namespace memref {
2629

2730
/// Returns true, if the memref type has static shapes and represents a
@@ -93,6 +96,20 @@ computeStridesIRBlock(Location loc, OpBuilder &builder,
9396
return computeSuffixProductIRBlock(loc, builder, sizes);
9497
}
9598

99+
/// Walk up the source chain until an operation that changes/defines the view of
100+
/// memory is found (i.e. skip operations that alias the entire view).
101+
MemrefValue skipFullyAliasingOperations(MemrefValue source);
102+
103+
/// Checks if two (memref) values are the same or are statically known to alias
104+
/// the same region of memory.
105+
inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
106+
return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
107+
}
108+
109+
/// Walk up the source chain until something an op other than a `memref.subview`
110+
/// or `memref.cast` is found.
111+
MemrefValue skipSubViewsAndCasts(MemrefValue source);
112+
96113
} // namespace memref
97114
} // namespace mlir
98115

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,5 +178,35 @@ computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
178178
return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
179179
}
180180

181+
MemrefValue skipFullyAliasingOperations(MemrefValue source) {
182+
while (auto op = source.getDefiningOp()) {
183+
if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
184+
subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
185+
// A `memref.subview` with an all zero offset, and all unit strides, still
186+
// points to the same memory.
187+
source = cast<MemrefValue>(subViewOp.getSource());
188+
} else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
189+
// A `memref.cast` still points to the same memory.
190+
source = castOp.getSource();
191+
} else {
192+
return source;
193+
}
194+
}
195+
return source;
196+
}
197+
198+
MemrefValue skipSubViewsAndCasts(MemrefValue source) {
199+
while (auto op = source.getDefiningOp()) {
200+
if (auto subView = dyn_cast<memref::SubViewOp>(op)) {
201+
source = cast<MemrefValue>(subView.getSource());
202+
} else if (auto cast = dyn_cast<memref::CastOp>(op)) {
203+
source = cast.getSource();
204+
} else {
205+
return source;
206+
}
207+
}
208+
return source;
209+
}
210+
181211
} // namespace memref
182212
} // namespace mlir

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 11 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/MemRef/IR/MemRef.h"
17+
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1718
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1819
#include "mlir/Dialect/Utils/IndexingUtils.h"
1920
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -88,46 +89,6 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
8889
return false;
8990
}
9091

91-
/// Walk up the source chain until an operation that changes/defines the view of
92-
/// memory is found (i.e. skip operations that alias the entire view).
93-
Value skipFullyAliasingOperations(Value source) {
94-
while (auto op = source.getDefiningOp()) {
95-
if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
96-
subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
97-
// A `memref.subview` with an all zero offset, and all unit strides, still
98-
// points to the same memory.
99-
source = subViewOp.getSource();
100-
} else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
101-
// A `memref.cast` still points to the same memory.
102-
source = castOp.getSource();
103-
} else {
104-
return source;
105-
}
106-
}
107-
return source;
108-
}
109-
110-
/// Checks if two (memref) values are the same or are statically known to alias
111-
/// the same region of memory.
112-
bool isSameViewOrTrivialAlias(Value a, Value b) {
113-
return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
114-
}
115-
116-
/// Walk up the source chain until something an op other than a `memref.subview`
117-
/// or `memref.cast` is found.
118-
Value skipSubViewsAndCasts(Value source) {
119-
while (auto op = source.getDefiningOp()) {
120-
if (auto subView = dyn_cast<memref::SubViewOp>(op)) {
121-
source = subView.getSource();
122-
} else if (auto cast = dyn_cast<memref::CastOp>(op)) {
123-
source = cast.getSource();
124-
} else {
125-
return source;
126-
}
127-
}
128-
return source;
129-
}
130-
13192
/// For transfer_write to overwrite fully another transfer_write must:
13293
/// 1. Access the same memref with the same indices and vector type.
13394
/// 2. Post-dominate the other transfer_write operation.
@@ -144,7 +105,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
144105
<< "\n");
145106
llvm::SmallVector<Operation *, 8> blockingAccesses;
146107
Operation *firstOverwriteCandidate = nullptr;
147-
Value source = skipSubViewsAndCasts(write.getSource());
108+
Value source =
109+
memref::skipSubViewsAndCasts(cast<MemrefValue>(write.getSource()));
148110
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
149111
source.getUsers().end());
150112
llvm::SmallDenseSet<Operation *, 32> processed;
@@ -163,7 +125,9 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
163125
continue;
164126
if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
165127
// Check candidate that can override the store.
166-
if (isSameViewOrTrivialAlias(nextWrite.getSource(), write.getSource()) &&
128+
if (memref::isSameViewOrTrivialAlias(
129+
cast<MemrefValue>(nextWrite.getSource()),
130+
cast<MemrefValue>(write.getSource())) &&
167131
checkSameValueWAW(nextWrite, write) &&
168132
postDominators.postDominates(nextWrite, write)) {
169133
if (firstOverwriteCandidate == nullptr ||
@@ -228,7 +192,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
228192
<< "\n");
229193
SmallVector<Operation *, 8> blockingWrites;
230194
vector::TransferWriteOp lastwrite = nullptr;
231-
Value source = skipSubViewsAndCasts(read.getSource());
195+
Value source =
196+
memref::skipSubViewsAndCasts(cast<MemrefValue>(read.getSource()));
232197
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
233198
source.getUsers().end());
234199
llvm::SmallDenseSet<Operation *, 32> processed;
@@ -251,7 +216,9 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
251216
cast<VectorTransferOpInterface>(read.getOperation()),
252217
/*testDynamicValueUsingBounds=*/true))
253218
continue;
254-
if (isSameViewOrTrivialAlias(read.getSource(), write.getSource()) &&
219+
if (memref::isSameViewOrTrivialAlias(
220+
cast<MemrefValue>(read.getSource()),
221+
cast<MemrefValue>(write.getSource())) &&
255222
dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
256223
if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
257224
lastwrite = write;

0 commit comments

Comments
 (0)