
Commit 4e2efea

[mlir][vector] Add all view-like ops to transfer flow opt (#110521)
`vector.transfer_*` folding and forwarding currently do not take reshaping view-like memref ops (expand and collapse shape) into account, which can lead to invalid store folding or value forwarding. This patch adds tracking for those (and other) view-like ops. It is still possible for an operation to alias a memref without being a view (e.g. a memref carried in the iter_args of an `scf.for`), so these patterns may still need revisiting in the future.
Parent: 9cd5e5c
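To make the hazard concrete, here is a minimal sketch, assuming hypothetical values `%v`, `%pad`, and `%c0` defined elsewhere (the new tests below exercise the same pattern inside a loop):

```mlir
%alloca = memref.alloca() : memref<1x4x1xi32>
%collapsed = memref.collapse_shape %alloca [[0, 1, 2]]
    : memref<1x4x1xi32> into memref<4xi32>
// This write goes through the collapsed view...
vector.transfer_write %v, %collapsed[%c0] {in_bounds = [true]}
    : vector<4xi32>, memref<4xi32>
// ...and this read observes it through the aliasing source %alloca,
// so the write is neither dead nor safe to forward a value past.
%r = vector.transfer_read %alloca[%c0, %c0, %c0], %pad
    {in_bounds = [true, true, true]} : memref<1x4x1xi32>, vector<1x4x1xi32>
```

Before this patch, the analysis only looked through `memref.subview` and `memref.cast`, so it could miss the alias between `%alloca` and `%collapsed`.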

File tree: 4 files changed (+107 −18 lines)

mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h

Lines changed: 3 additions & 3 deletions
```diff
@@ -106,9 +106,9 @@ inline bool isSameViewOrTrivialAlias(MemrefValue a, MemrefValue b) {
   return skipFullyAliasingOperations(a) == skipFullyAliasingOperations(b);
 }
 
-/// Walk up the source chain until something an op other than a `memref.subview`
-/// or `memref.cast` is found.
-MemrefValue skipSubViewsAndCasts(MemrefValue source);
+/// Walk up the source chain until we find an operation that is not a view of
+/// the source memref (i.e. implements ViewLikeOpInterface).
+MemrefValue skipViewLikeOps(MemrefValue source);
 
 } // namespace memref
 } // namespace mlir
```

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 6 additions & 7 deletions
```diff
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
 #include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
@@ -193,15 +194,13 @@ MemrefValue skipFullyAliasingOperations(MemrefValue source) {
   return source;
 }
 
-MemrefValue skipSubViewsAndCasts(MemrefValue source) {
+MemrefValue skipViewLikeOps(MemrefValue source) {
   while (auto op = source.getDefiningOp()) {
-    if (auto subView = dyn_cast<memref::SubViewOp>(op)) {
-      source = cast<MemrefValue>(subView.getSource());
-    } else if (auto cast = dyn_cast<memref::CastOp>(op)) {
-      source = cast.getSource();
-    } else {
-      return source;
+    if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) {
+      source = cast<MemrefValue>(viewLike.getViewSource());
+      continue;
     }
+    return source;
   }
   return source;
 }
```
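With the interface-based walk, the helper looks through any chain of view-producing ops, not just `memref.subview` and `memref.cast`. A minimal sketch of a chain it now handles (hypothetical IR, not from this commit):

```mlir
%alloc = memref.alloc() : memref<20xi32>
%expanded = memref.expand_shape %alloc [[0, 1]] output_shape [5, 4]
    : memref<20xi32> into memref<5x4xi32>
%cast = memref.cast %expanded : memref<5x4xi32> to memref<?x4xi32>
// skipViewLikeOps starting from %cast steps through the cast and the
// expand_shape (both implement ViewLikeOpInterface) and returns %alloc.
```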

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 4 additions & 6 deletions
```diff
@@ -105,8 +105,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
                     << "\n");
   llvm::SmallVector<Operation *, 8> blockingAccesses;
   Operation *firstOverwriteCandidate = nullptr;
-  Value source =
-      memref::skipSubViewsAndCasts(cast<MemrefValue>(write.getSource()));
+  Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getSource()));
   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                            source.getUsers().end());
   llvm::SmallDenseSet<Operation *, 32> processed;
@@ -115,7 +114,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
     // If the user has already been processed skip.
     if (!processed.insert(user).second)
       continue;
-    if (isa<memref::SubViewOp, memref::CastOp>(user)) {
+    if (isa<ViewLikeOpInterface>(user)) {
      users.append(user->getUsers().begin(), user->getUsers().end());
      continue;
    }
@@ -192,8 +191,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
                     << "\n");
   SmallVector<Operation *, 8> blockingWrites;
   vector::TransferWriteOp lastwrite = nullptr;
-  Value source =
-      memref::skipSubViewsAndCasts(cast<MemrefValue>(read.getSource()));
+  Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getSource()));
   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
                                            source.getUsers().end());
   llvm::SmallDenseSet<Operation *, 32> processed;
@@ -202,7 +200,7 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
     // If the user has already been processed skip.
     if (!processed.insert(user).second)
       continue;
-    if (isa<memref::SubViewOp, memref::CollapseShapeOp, memref::CastOp>(user)) {
+    if (isa<ViewLikeOpInterface>(user)) {
       users.append(user->getUsers().begin(), user->getUsers().end());
       continue;
     }
```
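Symmetrically, when collecting accesses that may conflict with a transfer, users of any view-like op over the source are now treated (transitively) as users of the source itself. A minimal sketch of what the worklist traversal reaches (hypothetical IR, not from this commit):

```mlir
%alloca = memref.alloca() : memref<1x4x1xi32>
%collapsed = memref.collapse_shape %alloca [[0, 1, 2]]
    : memref<1x4x1xi32> into memref<4xi32>
%cast = memref.cast %collapsed : memref<4xi32> to memref<?xi32>
// Analyzing a transfer on %alloca, the worklist visits %collapsed (a
// view-like user of the source), then %cast (a view-like user of that
// view), and finally any vector.transfer_read/write on either view,
// which can block dead-store elimination or store-to-load forwarding.
```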

mlir/test/Dialect/Vector/vector-transferop-opt.mlir

Lines changed: 94 additions & 2 deletions
```diff
@@ -229,15 +229,15 @@ func.func @forward_dead_store_negative(%arg0: i1, %arg1 : memref<4x4xf32>,
 // final `vector.transfer_write` should be preserved as:
 //   vector.transfer_write %2, %subview
 
-// CHECK-LABEL: func.func @collapse_shape
+// CHECK-LABEL: func.func @collapse_shape_and_read_from_source
 // CHECK: scf.for {{.*}} {
 // CHECK: vector.transfer_read
 // CHECK: vector.transfer_write
 // CHECK: vector.transfer_write
 // CHECK: vector.transfer_read
 // CHECK: vector.transfer_write
 
-func.func @collapse_shape(%in_0: memref<1x20x1xi32>, %vec: vector<4xi32>) {
+func.func @collapse_shape_and_read_from_source(%in_0: memref<1x20x1xi32>, %vec: vector<4xi32>) {
   %c0_i32 = arith.constant 0 : i32
   %c0 = arith.constant 0 : index
   %c4 = arith.constant 4 : index
@@ -257,6 +257,98 @@ func.func @collapse_shape(%in_0: memref<1x20x1xi32>, %vec: vector<4xi32>) {
   return
 }
 
+// The same regression test for expand_shape.
+
+// CHECK-LABEL: func.func @expand_shape_and_read_from_source
+// CHECK: scf.for {{.*}} {
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+
+func.func @expand_shape_and_read_from_source(%in_0: memref<20xi32>, %vec: vector<1x4x1xi32>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c20 = arith.constant 20 : index
+
+  %alloca = memref.alloca() {alignment = 64 : i64} : memref<4xi32>
+  %expand_shape = memref.expand_shape %alloca [[0, 1, 2]] output_shape [1, 4, 1] : memref<4xi32> into memref<1x4x1xi32>
+  scf.for %arg0 = %c0 to %c20 step %c4 {
+    %subview = memref.subview %in_0[%arg0] [4] [1] : memref<20xi32> to memref<4xi32, strided<[1], offset: ?>>
+    %1 = vector.transfer_read %subview[%c0], %c0_i32 {in_bounds = [true]} : memref<4xi32, strided<[1], offset: ?>>, vector<4xi32>
+    // $alloca and $expand_shape alias
+    vector.transfer_write %1, %alloca[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32>
+    vector.transfer_write %vec, %expand_shape[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32>
+    %2 = vector.transfer_read %alloca[%c0], %c0_i32 {in_bounds = [true]} : memref<4xi32>, vector<4xi32>
+    vector.transfer_write %2, %subview[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32, strided<[1], offset: ?>>
+  }
+  return
+}
+
+// The same regression test, but the initial write is to the collapsed memref,
+// and the subsequent unforwardable read is from the collapse shape.
+
+// CHECK-LABEL: func.func @collapse_shape_and_read_from_collapse
+// CHECK: scf.for {{.*}} {
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+
+func.func @collapse_shape_and_read_from_collapse(%in_0: memref<20xi32>, %vec: vector<1x4x1xi32>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c20 = arith.constant 20 : index
+
+  %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x4x1xi32>
+  %collapse_shape = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x4x1xi32> into memref<4xi32>
+  scf.for %arg0 = %c0 to %c20 step %c4 {
+    %subview = memref.subview %in_0[%arg0] [4] [1] : memref<20xi32> to memref<4xi32, strided<[1], offset: ?>>
+    %1 = vector.transfer_read %subview[%c0], %c0_i32 {in_bounds = [true]} : memref<4xi32, strided<[1], offset: ?>>, vector<4xi32>
+    vector.transfer_write %1, %collapse_shape[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32>
+    // $alloca and $collapse_shape alias
+    vector.transfer_write %vec, %alloca[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32>
+    %2 = vector.transfer_read %collapse_shape[%c0], %c0_i32 {in_bounds = [true]} : memref<4xi32>, vector<4xi32>
+    vector.transfer_write %2, %subview[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32, strided<[1], offset: ?>>
+  }
+  return
+}
+
+// The same test except writing to the expanded source first (same as the
+// previous collapse test but for expand).
+
+// CHECK-LABEL: func.func @expand_shape_and_read_from_expand
+// CHECK: scf.for {{.*}} {
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_write
+// CHECK: vector.transfer_read
+// CHECK: vector.transfer_write
+
+func.func @expand_shape_and_read_from_expand(%in_0: memref<1x20x1xi32>, %vec: vector<4xi32>) {
+  %c0_i32 = arith.constant 0 : i32
+  %c0 = arith.constant 0 : index
+  %c4 = arith.constant 4 : index
+  %c20 = arith.constant 20 : index
+
+  %alloca = memref.alloca() {alignment = 64 : i64} : memref<4xi32>
+  %expand_shape = memref.expand_shape %alloca [[0, 1, 2]] output_shape [1, 4, 1] : memref<4xi32> into memref<1x4x1xi32>
+  scf.for %arg0 = %c0 to %c20 step %c4 {
+    %subview = memref.subview %in_0[0, %arg0, 0] [1, 4, 1] [1, 1, 1] : memref<1x20x1xi32> to memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>
+    %1 = vector.transfer_read %subview[%c0, %c0, %c0], %c0_i32 {in_bounds = [true, true, true]} : memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>, vector<1x4x1xi32>
+    vector.transfer_write %1, %expand_shape[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32>
+    // $alloca and $expand_shape alias
+    vector.transfer_write %vec, %alloca[%c0] {in_bounds = [true]} : vector<4xi32>, memref<4xi32>
+    %2 = vector.transfer_read %expand_shape[%c0, %c0, %c0], %c0_i32 {in_bounds = [true, true, true]} : memref<1x4x1xi32>, vector<1x4x1xi32>
+    vector.transfer_write %2, %subview[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x1xi32>, memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>>
+  }
+  return
+}
+
 // CHECK-LABEL: func @forward_dead_store_dynamic_same_index
 // CHECK-NOT: vector.transfer_write
 // CHECK-NOT: vector.transfer_read
```
