Skip to content

Commit dc82547

Browse files
author
gysit
committed
[mlir][vector] Make write permutation lowering work with tensors.
Use type inference when building the TransferWriteOp in the TransferWritePermutationLowering. Previously, the result type has been set to Type() which triggers an assertion if the pattern is used with tensors instead of memrefs. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D118758
1 parent 31cca9e commit dc82547

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,8 @@ struct TransferWritePermutationLowering
185185
auto newMap = AffineMap::getMinorIdentityMap(
186186
map.getNumDims(), map.getNumResults(), rewriter.getContext());
187187
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
188-
op, Type(), newVec, op.source(), op.indices(),
189-
AffineMapAttr::get(newMap), newMask, newInBoundsAttr);
188+
op, newVec, op.source(), op.indices(), AffineMapAttr::get(newMap),
189+
newMask, newInBoundsAttr);
190190

191191
return success();
192192
}

mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -327,21 +327,24 @@ func @transfer_read_permutations(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?x?x?
327327
// -----
328328

329329
// CHECK-LABEL: func @transfer_write_permutations
330-
func @transfer_write_permutations(%arg0 : memref<?x?x?x?xf32>,
331-
%v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> () {
330+
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?x?x?xf32>
331+
// CHECK-SAME: %[[ARG1:.*]]: tensor<?x?x?x?xf32>
332+
func @transfer_write_permutations(
333+
%arg0 : memref<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
334+
%v1 : vector<7x14x8x16xf32>, %v2 : vector<8x16xf32>) -> tensor<?x?x?x?xf32> {
332335
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
333336
%c0 = arith.constant 0 : index
334337
%m = arith.constant 1 : i1
335338

336339
%mask0 = splat %m : vector<7x14x8x16xi1>
337-
vector.transfer_write %v1, %arg0[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, memref<?x?x?x?xf32>
340+
%0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0], %mask0 {in_bounds = [true, false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)>} : vector<7x14x8x16xf32>, tensor<?x?x?x?xf32>
338341
// CHECK: %[[NEW_MASK0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xi1> to vector<8x14x16x7xi1>
339342
// CHECK: %[[NEW_VEC0:.*]] = vector.transpose %{{.*}} [2, 1, 3, 0] : vector<7x14x8x16xf32> to vector<8x14x16x7xf32>
340-
// CHECK: vector.transfer_write %[[NEW_VEC0]], %arg0[%c0, %c0, %c0, %c0], %[[NEW_MASK0]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, memref<?x?x?x?xf32>
343+
// CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC0]], %[[ARG1]][%c0, %c0, %c0, %c0], %[[NEW_MASK0]] {in_bounds = [false, false, true, true]} : vector<8x14x16x7xf32>, tensor<?x?x?x?xf32>
341344

342345
vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
343346
// CHECK: %[[NEW_VEC1:.*]] = vector.transpose %{{.*}} [1, 0] : vector<8x16xf32> to vector<16x8xf32>
344-
// CHECK: vector.transfer_write %[[NEW_VEC1]], %arg0[%c0, %c0, %c0, %c0] : vector<16x8xf32>, memref<?x?x?x?xf32>
347+
// CHECK: vector.transfer_write %[[NEW_VEC1]], %[[ARG0]][%c0, %c0, %c0, %c0] : vector<16x8xf32>, memref<?x?x?x?xf32>
345348

346-
return
349+
return %0 : tensor<?x?x?x?xf32>
347350
}

0 commit comments

Comments
 (0)