Skip to content

Commit 6040044

Browse files
[mlir][vector] VectorToSCF: Omit redundant out-of-bounds check
There was a bug in `TransferWriteNonPermutationLowering`, a pattern that extends the permutation map of a TransferWriteOp with leading transfer dimensions of size one. These newly added transfer dimensions are always in-bounds, because the starting point of any dimension is in-bounds. However, when the original op had no "in_bounds" attribute, the rewritten op was created without one as well, so the new unit dimensions were not marked as in-bounds. VectorToSCF inserts out-of-bounds checks based on the "in_bounds" attribute, so dims that are marked as out-of-bounds but are actually always in-bounds lead to unnecessary "scf.if" ops. Differential Revision: https://reviews.llvm.org/D155196
1 parent 2ac9920 commit 6040044

File tree

3 files changed

+9
-13
lines changed

3 files changed

+9
-13
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -272,15 +272,12 @@ struct TransferWriteNonPermutationLowering
272272
exprs.append(map.getResults().begin(), map.getResults().end());
273273
AffineMap newMap =
274274
AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
275-
ArrayAttr newInBoundsAttr;
276-
if (op.getInBounds()) {
277-
// All the new dimensions added are inbound.
278-
SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
279-
for (Attribute attr : op.getInBounds().value().getValue()) {
280-
newInBoundsValues.push_back(cast<BoolAttr>(attr).getValue());
281-
}
282-
newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
275+
// All the new dimensions added are inbound.
276+
SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
277+
for (int64_t i = 0, e = op.getVectorType().getRank(); i < e; ++i) {
278+
newInBoundsValues.push_back(op.isDimInBounds(i));
283279
}
280+
ArrayAttr newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
284281
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
285282
op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
286283
newMask, newInBoundsAttr);

mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,9 @@ func.func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
171171
// CHECK: %[[S1:.*]] = affine.apply #[[$ADD]](%[[I1]], %[[I5]])
172172
// CHECK: %[[VECTOR_VIEW3:.*]] = vector.type_cast %[[VECTOR_VIEW2]] : memref<3x4xvector<1x5xf32>> to memref<3x4x1xvector<5xf32>>
173173
// CHECK: scf.for %[[I6:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
174-
// CHECK: scf.if
175-
// CHECK: %[[S0:.*]] = affine.apply #[[$ADD]](%[[I2]], %[[I6]])
176-
// CHECK: %[[VEC:.*]] = memref.load %[[VECTOR_VIEW3]][%[[I4]], %[[I5]], %[[I6]]] : memref<3x4x1xvector<5xf32>>
177-
// CHECK: vector.transfer_write %[[VEC]], %{{.*}}[%[[S3]], %[[S1]], %[[S0]], %[[I3]]] : vector<5xf32>, memref<?x?x?x?xf32>
174+
// CHECK: %[[S0:.*]] = affine.apply #[[$ADD]](%[[I2]], %[[I6]])
175+
// CHECK: %[[VEC:.*]] = memref.load %[[VECTOR_VIEW3]][%[[I4]], %[[I5]], %[[I6]]] : memref<3x4x1xvector<5xf32>>
176+
// CHECK: vector.transfer_write %[[VEC]], %{{.*}}[%[[S3]], %[[S1]], %[[S0]], %[[I3]]] : vector<5xf32>, memref<?x?x?x?xf32>
178177
// CHECK: }
179178
// CHECK: }
180179
// CHECK: }

mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ func.func @transfer_write_broadcast_unit_dim(
355355
vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
356356
// CHECK: %[[NEW_VEC2:.*]] = vector.broadcast %{{.*}} : vector<8x16xf32> to vector<1x8x16xf32>
357357
// CHECK: %[[NEW_VEC3:.*]] = vector.transpose %[[NEW_VEC2]], [1, 2, 0] : vector<1x8x16xf32> to vector<8x16x1xf32>
358-
// CHECK: vector.transfer_write %[[NEW_VEC3]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] : vector<8x16x1xf32>, memref<?x?x?x?xf32>
358+
// CHECK: vector.transfer_write %[[NEW_VEC3]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true]} : vector<8x16x1xf32>, memref<?x?x?x?xf32>
359359

360360
return %0 : tensor<?x?x?x?xf32>
361361
}

0 commit comments

Comments
 (0)