Skip to content

Commit 226896c

Browse files
[mlir][linalg] Fix bug in vector transfer hoisting
Do not hoist vector transfers that do not match exactly. In particular, do not hoist transfers with different vector types. This has lead to invalid IR (yielded vector type is different from iter_arg type) in downstream projects. Differential Revision: https://reviews.llvm.org/D155052
1 parent 67f1e8d commit 226896c

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,14 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
135135
<< "\n");
136136

137137
// Approximate aliasing by checking that:
138-
// 1. indices are the same,
138+
// 1. indices, vector type and permutation map are the same (i.e., the
139+
// transfer_read/transfer_write ops are matching),
139140
// 2. no other operations in the loop access the same memref except
140141
// for transfer_read/transfer_write accessing statically disjoint
141142
// slices.
142-
if (transferRead.getIndices() != transferWrite.getIndices() &&
143-
transferRead.getVectorType() == transferWrite.getVectorType())
143+
if (transferRead.getIndices() != transferWrite.getIndices() ||
144+
transferRead.getVectorType() != transferWrite.getVectorType() ||
145+
transferRead.getPermutationMap() != transferWrite.getPermutationMap())
144146
return WalkResult::advance();
145147

146148
// TODO: may want to memoize this information for performance but it

mlir/test/Dialect/Linalg/hoisting.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,3 +722,37 @@ transform.sequence failures(propagate) {
722722
transform.structured.hoist_redundant_vector_transfers %0
723723
: (!transform.any_op) -> !transform.any_op
724724
}
725+
726+
// -----
727+
728+
// The transfers in this test case cannot be hoisted and replaced by a vector
729+
// iter_arg because they do not match.
730+
731+
// CHECK-LABEL: func.func @non_matching_transfers(
732+
// CHECK: scf.for {{.*}} {
733+
// CHECK: vector.transfer_read
734+
// CHECK: vector.transfer_write
735+
// CHECK: }
736+
func.func @non_matching_transfers(%m: memref<6x1x7x32xf32>) {
737+
%c0 = arith.constant 0 : index
738+
%c1024 = arith.constant 1024 : index
739+
%c128 = arith.constant 128 : index
740+
%cst = arith.constant dense<5.5> : vector<6x7x32xf32>
741+
%cst_0 = arith.constant 0.0 : f32
742+
scf.for %iv = %c0 to %c1024 step %c128 {
743+
%read = vector.transfer_read %m[%c0, %c0, %c0, %c0], %cst_0 {in_bounds = [true, true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>} : memref<6x1x7x32xf32>, vector<6x7x32xf32>
744+
%added = arith.addf %read, %cst : vector<6x7x32xf32>
745+
%bc = vector.broadcast %added : vector<6x7x32xf32> to vector<1x6x7x32xf32>
746+
%tr = vector.transpose %bc, [1, 0, 2, 3] : vector<1x6x7x32xf32> to vector<6x1x7x32xf32>
747+
vector.transfer_write %tr, %m[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true, true]} : vector<6x1x7x32xf32>, memref<6x1x7x32xf32>
748+
}
749+
return
750+
}
751+
752+
transform.sequence failures(propagate) {
753+
^bb1(%arg1: !transform.any_op):
754+
%0 = transform.structured.match ops{["func.func"]} in %arg1
755+
: (!transform.any_op) -> !transform.any_op
756+
transform.structured.hoist_redundant_vector_transfers %0
757+
: (!transform.any_op) -> !transform.any_op
758+
}

0 commit comments

Comments
 (0)