Skip to content

Commit 93b80d7

Browse files
committed
[mlir][vector] Teach TransferOptimization look through trivial aliases
This allows `TransferOptimization` to eliminate and forward stores that are to trivial aliases (rather than just to identical memref values). A trivial alias is (currently) defined as: 1. A `memref.cast` 2. A `memref.subview` with a zero offset and unit strides 3. A chain of 1 and 2
1 parent 8cee94e commit 93b80d7

File tree

2 files changed

+79
-19
lines changed

2 files changed

+79
-19
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 48 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,46 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
8888
return false;
8989
}
9090

91+
/// Walk up the source chain until an operation that changes/defines the view of
92+
/// memory is found (i.e. skip operations that alias the entire view).
93+
Value skipFullyAliasingOperations(Value source) {
94+
while (auto op = source.getDefiningOp()) {
95+
if (auto subViewOp = dyn_cast<memref::SubViewOp>(op);
96+
subViewOp && subViewOp.hasZeroOffset() && subViewOp.hasUnitStride()) {
97+
// A `memref.subview` with an all zero offset, and all unit strides, still
98+
// points to the same memory.
99+
source = subViewOp.getSource();
100+
} else if (auto castOp = dyn_cast<memref::CastOp>(op)) {
101+
// A `memref.cast` still points to the same memory.
102+
source = castOp.getSource();
103+
} else {
104+
return source;
105+
}
106+
}
107+
return source;
108+
}
109+
110+
/// Checks if two (memref) values are are the same, or are statically known to
111+
/// alias the same region of memory.
112+
bool isSameViewOrTrivialAlias(Value a, Value b) {
113+
return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
114+
}
115+
116+
/// Walk up the source chain until something an op other than a `memref.subview`
117+
/// or `memref.cast` is found.
118+
Value skipSubViewsAndCasts(Value source) {
119+
while (auto op = source.getDefiningOp()) {
120+
if (auto subView = dyn_cast<memref::SubViewOp>(op)) {
121+
source = subView.getSource();
122+
} else if (auto cast = dyn_cast<memref::CastOp>(op)) {
123+
source = cast.getSource();
124+
} else {
125+
return source;
126+
}
127+
}
128+
return source;
129+
}
130+
91131
/// For transfer_write to overwrite fully another transfer_write must:
92132
/// 1. Access the same memref with the same indices and vector type.
93133
/// 2. Post-dominate the other transfer_write operation.
@@ -104,10 +144,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
104144
<< "\n");
105145
llvm::SmallVector<Operation *, 8> blockingAccesses;
106146
Operation *firstOverwriteCandidate = nullptr;
107-
Value source = write.getSource();
108-
// Skip subview ops.
109-
while (auto subView = source.getDefiningOp<memref::SubViewOp>())
110-
source = subView.getSource();
147+
Value source = skipSubViewsAndCasts(write.getSource());
111148
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
112149
source.getUsers().end());
113150
llvm::SmallDenseSet<Operation *, 32> processed;
@@ -116,8 +153,8 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
116153
// If the user has already been processed skip.
117154
if (!processed.insert(user).second)
118155
continue;
119-
if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
120-
users.append(subView->getUsers().begin(), subView->getUsers().end());
156+
if (isa<memref::SubViewOp, memref::CastOp>(user)) {
157+
users.append(user->getUsers().begin(), user->getUsers().end());
121158
continue;
122159
}
123160
if (isMemoryEffectFree(user))
@@ -126,7 +163,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
126163
continue;
127164
if (auto nextWrite = dyn_cast<vector::TransferWriteOp>(user)) {
128165
// Check candidate that can override the store.
129-
if (write.getSource() == nextWrite.getSource() &&
166+
if (isSameViewOrTrivialAlias(nextWrite.getSource(), write.getSource()) &&
130167
checkSameValueWAW(nextWrite, write) &&
131168
postDominators.postDominates(nextWrite, write)) {
132169
if (firstOverwriteCandidate == nullptr ||
@@ -191,10 +228,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
191228
<< "\n");
192229
SmallVector<Operation *, 8> blockingWrites;
193230
vector::TransferWriteOp lastwrite = nullptr;
194-
Value source = read.getSource();
195-
// Skip subview ops.
196-
while (auto subView = source.getDefiningOp<memref::SubViewOp>())
197-
source = subView.getSource();
231+
Value source = skipSubViewsAndCasts(read.getSource());
198232
llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
199233
source.getUsers().end());
200234
llvm::SmallDenseSet<Operation *, 32> processed;
@@ -203,12 +237,8 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
203237
// If the user has already been processed skip.
204238
if (!processed.insert(user).second)
205239
continue;
206-
if (auto subView = dyn_cast<memref::SubViewOp>(user)) {
207-
users.append(subView->getUsers().begin(), subView->getUsers().end());
208-
continue;
209-
}
210-
if (auto collapsed = dyn_cast<memref::CollapseShapeOp>(user)) {
211-
users.append(collapsed->getUsers().begin(), collapsed->getUsers().end());
240+
if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
241+
users.append(user->getUsers().begin(), user->getUsers().end());
212242
continue;
213243
}
214244
if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user))
@@ -221,7 +251,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
221251
cast<VectorTransferOpInterface>(read.getOperation()),
222252
/*testDynamicValueUsingBounds=*/true))
223253
continue;
224-
if (write.getSource() == read.getSource() &&
254+
if (isSameViewOrTrivialAlias(read.getSource(), write.getSource()) &&
225255
dominators.dominates(write, read) && checkSameValueRAW(write, read)) {
226256
if (lastwrite == nullptr || dominators.dominates(lastwrite, write))
227257
lastwrite = write;

mlir/test/Dialect/Vector/vector-transferop-opt.mlir

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
222222
// `vector.transfer_write` would not be safe:
223223
// %1 = vector.transfer_read %subview
224224
// vector.transfer_write %1, %alloca
225-
// vector.transfer_write %vec, %collapse_shape
225+
// vector.transfer_write %vec, %collapse_shape
226226
// %2 = vector.transfer_read %alloca
227227
// vector.transfer_write %1, %subview
228228
// Indeed, %alloca and %collapse_shape alias and hence %2 != %1. Instead, the
@@ -360,3 +360,33 @@ func.func @forward_dead_store_dynamic_non_overlap_trailing_dim(
360360
vector.transfer_write %x, %buffer[%i0, %i0] {in_bounds = [true]} : vector<4xf32>, memref<?x?xf32>
361361
return
362362
}
363+
364+
// Here each read/write is to a different subview, but they all point to the exact
365+
// same bit of memory (just through casts and subviews with unit strides and
366+
// zero offsets).
367+
// CHECK-LABEL: func @forward_and_eliminate_stores_through_trivial_aliases
368+
// CHECK-NOT: vector.transfer_write
369+
// CHECK-NOT: vector.transfer_read
370+
// CHECK: scf.for
371+
// CHECK: }
372+
// CHECK: vector.transfer_write
373+
// CHECK: return
374+
func.func @forward_and_eliminate_stores_through_trivial_aliases(
375+
%buffer : memref<?x?xf32>, %vec: vector<[8]x[8]xf32>, %size: index, %a_size: index, %another_size: index
376+
) {
377+
%c0 = arith.constant 0 : index
378+
%c1 = arith.constant 1 : index
379+
%c32 = arith.constant 32 : index
380+
%cst = arith.constant 0.0 : f32
381+
vector.transfer_write %vec, %buffer[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32>
382+
%direct_subview = memref.subview %buffer[0, 0] [%a_size, %a_size] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
383+
%cast = memref.cast %direct_subview : memref<?x?xf32, strided<[?, 1], offset: ?>> to memref<?x?xf32>
384+
%subview_of_cast = memref.subview %cast[0, 0] [%another_size, %another_size] [1, 1] : memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
385+
%21 = vector.transfer_read %direct_subview[%c0, %c0], %cst {in_bounds = [true, true]} : memref<?x?xf32, strided<[?, 1], offset: ?>>, vector<[8]x[8]xf32>
386+
%23 = scf.for %arg2 = %c0 to %c32 step %c1 iter_args(%arg3 = %21) -> (vector<[8]x[8]xf32>) {
387+
%24 = arith.addf %arg3, %arg3 : vector<[8]x[8]xf32>
388+
scf.yield %24 : vector<[8]x[8]xf32>
389+
}
390+
vector.transfer_write %23, %subview_of_cast[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf32>, memref<?x?xf32, strided<[?, 1], offset: ?>>
391+
return
392+
}

0 commit comments

Comments
 (0)