14
14
#include " mlir/Dialect/Affine/IR/AffineOps.h"
15
15
#include " mlir/Dialect/Arith/IR/Arith.h"
16
16
#include " mlir/Dialect/MemRef/IR/MemRef.h"
17
+ #include " mlir/Dialect/MemRef/Utils/MemRefUtils.h"
17
18
#include " mlir/Dialect/Tensor/IR/Tensor.h"
18
19
#include " mlir/Dialect/Utils/IndexingUtils.h"
19
20
#include " mlir/Dialect/Vector/IR/VectorOps.h"
@@ -88,46 +89,6 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
88
89
return false ;
89
90
}
90
91
91
- // / Walk up the source chain until an operation that changes/defines the view of
92
- // / memory is found (i.e. skip operations that alias the entire view).
93
- Value skipFullyAliasingOperations (Value source) {
94
- while (auto op = source.getDefiningOp ()) {
95
- if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
96
- subViewOp && subViewOp.hasZeroOffset () && subViewOp.hasUnitStride ()) {
97
- // A `memref.subview` with an all zero offset, and all unit strides, still
98
- // points to the same memory.
99
- source = subViewOp.getSource ();
100
- } else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
101
- // A `memref.cast` still points to the same memory.
102
- source = castOp.getSource ();
103
- } else {
104
- return source;
105
- }
106
- }
107
- return source;
108
- }
109
-
110
- // / Checks if two (memref) values are the same or are statically known to alias
111
- // / the same region of memory.
112
- bool isSameViewOrTrivialAlias (Value a, Value b) {
113
- return skipFullyAliasingOperations (a) == skipFullyAliasingOperations (b);
114
- }
115
-
116
- // / Walk up the source chain until something an op other than a `memref.subview`
117
- // / or `memref.cast` is found.
118
- Value skipSubViewsAndCasts (Value source) {
119
- while (auto op = source.getDefiningOp ()) {
120
- if (auto subView = dyn_cast<memref::SubViewOp>(op)) {
121
- source = subView.getSource ();
122
- } else if (auto cast = dyn_cast<memref::CastOp>(op)) {
123
- source = cast.getSource ();
124
- } else {
125
- return source;
126
- }
127
- }
128
- return source;
129
- }
130
-
131
92
// / For transfer_write to overwrite fully another transfer_write must:
132
93
// / 1. Access the same memref with the same indices and vector type.
133
94
// / 2. Post-dominate the other transfer_write operation.
@@ -144,7 +105,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
144
105
<< " \n " );
145
106
llvm::SmallVector<Operation *, 8 > blockingAccesses;
146
107
Operation *firstOverwriteCandidate = nullptr ;
147
- Value source = skipSubViewsAndCasts (write.getSource ());
108
+ Value source =
109
+ memref::skipSubViewsAndCasts (cast<MemrefValue>(write.getSource ()));
148
110
llvm::SmallVector<Operation *, 32 > users (source.getUsers ().begin (),
149
111
source.getUsers ().end ());
150
112
llvm::SmallDenseSet<Operation *, 32 > processed;
@@ -163,7 +125,9 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
163
125
continue ;
164
126
if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
165
127
// Check candidate that can override the store.
166
- if (isSameViewOrTrivialAlias (nextWrite.getSource (), write.getSource ()) &&
128
+ if (memref::isSameViewOrTrivialAlias (
129
+ cast<MemrefValue>(nextWrite.getSource ()),
130
+ cast<MemrefValue>(write.getSource ())) &&
167
131
checkSameValueWAW (nextWrite, write) &&
168
132
postDominators.postDominates (nextWrite, write)) {
169
133
if (firstOverwriteCandidate == nullptr ||
@@ -228,7 +192,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
228
192
<< " \n " );
229
193
SmallVector<Operation *, 8 > blockingWrites;
230
194
vector::TransferWriteOp lastwrite = nullptr ;
231
- Value source = skipSubViewsAndCasts (read.getSource ());
195
+ Value source =
196
+ memref::skipSubViewsAndCasts (cast<MemrefValue>(read.getSource ()));
232
197
llvm::SmallVector<Operation *, 32 > users (source.getUsers ().begin (),
233
198
source.getUsers ().end ());
234
199
llvm::SmallDenseSet<Operation *, 32 > processed;
@@ -251,7 +216,9 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
251
216
cast<VectorTransferOpInterface>(read.getOperation ()),
252
217
/* testDynamicValueUsingBounds=*/ true ))
253
218
continue ;
254
- if (isSameViewOrTrivialAlias (read.getSource (), write.getSource ()) &&
219
+ if (memref::isSameViewOrTrivialAlias (
220
+ cast<MemrefValue>(read.getSource ()),
221
+ cast<MemrefValue>(write.getSource ())) &&
255
222
dominators.dominates (write, read) && checkSameValueRAW (write, read)) {
256
223
if (lastwrite == nullptr || dominators.dominates (lastwrite, write))
257
224
lastwrite = write;
0 commit comments