-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Fix unit dim dropping pattern for masked writes #74038
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][vector] Fix unit dim dropping pattern for masked writes #74038
Conversation
This does the same as llvm#72142 for vector.transfer_write. Previously the pattern would silently drop the mask.
@llvm/pr-subscribers-mlir Author: Quinn Dawkins (qedawkins) ChangesThis does the same as #72142 for vector.transfer_write. Previously the pattern would silently drop the mask. Full diff: https://github.com/llvm/llvm-project/pull/74038.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index d2c6ba557b9bbec..0dc097158a4a55d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -260,14 +260,6 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
opToErase.push_back(read.getOperation());
}
-/// Returns a copy of `shape` without unit dims.
-static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
- SmallVector<int64_t> reducedShape;
- llvm::copy_if(shape, std::back_inserter(reducedShape),
- [](int64_t dimSize) { return dimSize != 1; });
- return reducedShape;
-}
-
/// Converts OpFoldResults to int64_t shape without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
SmallVector<int64_t> reducedShape;
@@ -446,9 +438,7 @@ class TransferWriteDropUnitDimsPattern
Value source = transferWriteOp.getSource();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// TODO: support tensor type.
- if (!sourceType || !sourceType.hasStaticShape())
- return failure();
- if (sourceType.getNumElements() != vectorType.getNumElements())
+ if (!sourceType)
return failure();
// TODO: generalize this pattern, relax the requirements here.
if (transferWriteOp.hasOutOfBoundsDim())
@@ -461,25 +451,39 @@ class TransferWriteDropUnitDimsPattern
return failure();
// Check if the reduced vector shape matches the reduced destination shape.
// Otherwise, this case is not supported yet.
- int vectorReducedRank = getReducedRank(vectorType.getShape());
- if (reducedRank != vectorReducedRank)
+ auto reducedVectorType = trimNonScalableUnitDims(vectorType);
+ if (reducedRank != reducedVectorType.getRank())
return failure();
if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
return getConstantIntValue(v) != static_cast<int64_t>(0);
}))
return failure();
+
+ Value maskOp = transferWriteOp.getMask();
+ if (maskOp) {
+ auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return rewriter.notifyMatchFailure(
+ transferWriteOp,
+ "unsupported mask op, only 'vector.create_mask' is "
+ "currently supported");
+ FailureOr<Value> rankReducedCreateMask =
+ createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+ if (failed(rankReducedCreateMask))
+ return failure();
+ maskOp = *rankReducedCreateMask;
+ }
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
- VectorType reducedVectorType = VectorType::get(
- getReducedShape(vectorType.getShape()), vectorType.getElementType());
-
+ SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType, vector);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
+ transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
+ identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
return success();
}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 735915d43565389..d65708068862f46 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -144,6 +144,50 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
+func.func @masked_transfer_write_and_vector_rank_reducing(
+ %arg : memref<1x1x3x1x16x1xf32>,
+ %vec : vector<1x3x1x16x1xf32>,
+ %mask_dim1 : index,
+ %mask_dim2 : index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %mask = vector.create_mask %c1, %mask_dim1, %c1, %mask_dim2, %c1 : vector<1x3x1x16x1xi1>
+ vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
+ vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
+ return
+}
+// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
+// CHECK-SAME: {{.*}}: vector<1x3x1x16x1xf32>,
+// CHECK-SAME: %[[MASKDIM1:.+]]: index,
+// CHECK-SAME: %[[MASKDIM2:.+]]: index
+// CHECK: %[[MASK:.+]] = vector.create_mask %[[MASKDIM1]], %[[MASKDIM2]] : vector<3x16xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
+// CHECK-SAME: memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
+// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
+
+func.func @masked_transfer_write_dynamic_rank_reducing(
+ %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+ %vec : vector<[16]x1xi8>,
+ %mask_dim0 : index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %pad = arith.constant 0 : i8
+ %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1>
+ vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} :
+ vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
+ return
+}
+// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing
+// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
+// CHECK-SAME: %{{.*}}: vector<[16]x1xi8>,
+// CHECK-SAME: %[[MASK_DIM0:.+]]: index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[MASK:.+]] = vector.create_mask %[[MASK_DIM0]] : vector<[16]xi1>
+// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+// CHECK: vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
+
/// Only masks operands of vector.create_mask are currently supported.
func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
|
@llvm/pr-subscribers-mlir-vector Author: Quinn Dawkins (qedawkins) ChangesThis does the same as #72142 for vector.transfer_write. Previously the pattern would silently drop the mask. Full diff: https://github.com/llvm/llvm-project/pull/74038.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index d2c6ba557b9bbec..0dc097158a4a55d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -260,14 +260,6 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
opToErase.push_back(read.getOperation());
}
-/// Returns a copy of `shape` without unit dims.
-static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
- SmallVector<int64_t> reducedShape;
- llvm::copy_if(shape, std::back_inserter(reducedShape),
- [](int64_t dimSize) { return dimSize != 1; });
- return reducedShape;
-}
-
/// Converts OpFoldResults to int64_t shape without unit dims.
static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
SmallVector<int64_t> reducedShape;
@@ -446,9 +438,7 @@ class TransferWriteDropUnitDimsPattern
Value source = transferWriteOp.getSource();
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
// TODO: support tensor type.
- if (!sourceType || !sourceType.hasStaticShape())
- return failure();
- if (sourceType.getNumElements() != vectorType.getNumElements())
+ if (!sourceType)
return failure();
// TODO: generalize this pattern, relax the requirements here.
if (transferWriteOp.hasOutOfBoundsDim())
@@ -461,25 +451,39 @@ class TransferWriteDropUnitDimsPattern
return failure();
// Check if the reduced vector shape matches the reduced destination shape.
// Otherwise, this case is not supported yet.
- int vectorReducedRank = getReducedRank(vectorType.getShape());
- if (reducedRank != vectorReducedRank)
+ auto reducedVectorType = trimNonScalableUnitDims(vectorType);
+ if (reducedRank != reducedVectorType.getRank())
return failure();
if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
return getConstantIntValue(v) != static_cast<int64_t>(0);
}))
return failure();
+
+ Value maskOp = transferWriteOp.getMask();
+ if (maskOp) {
+ auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+ if (!createMaskOp)
+ return rewriter.notifyMatchFailure(
+ transferWriteOp,
+ "unsupported mask op, only 'vector.create_mask' is "
+ "currently supported");
+ FailureOr<Value> rankReducedCreateMask =
+ createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+ if (failed(rankReducedCreateMask))
+ return failure();
+ maskOp = *rankReducedCreateMask;
+ }
Value reducedShapeSource =
rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
- VectorType reducedVectorType = VectorType::get(
- getReducedShape(vectorType.getShape()), vectorType.getElementType());
-
+ SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType, vector);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
+ transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
+ identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
return success();
}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 735915d43565389..d65708068862f46 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -144,6 +144,50 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
+func.func @masked_transfer_write_and_vector_rank_reducing(
+ %arg : memref<1x1x3x1x16x1xf32>,
+ %vec : vector<1x3x1x16x1xf32>,
+ %mask_dim1 : index,
+ %mask_dim2 : index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %mask = vector.create_mask %c1, %mask_dim1, %c1, %mask_dim2, %c1 : vector<1x3x1x16x1xi1>
+ vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
+ vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
+ return
+}
+// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
+// CHECK-SAME: {{.*}}: vector<1x3x1x16x1xf32>,
+// CHECK-SAME: %[[MASKDIM1:.+]]: index,
+// CHECK-SAME: %[[MASKDIM2:.+]]: index
+// CHECK: %[[MASK:.+]] = vector.create_mask %[[MASKDIM1]], %[[MASKDIM2]] : vector<3x16xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
+// CHECK-SAME: memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
+// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
+
+func.func @masked_transfer_write_dynamic_rank_reducing(
+ %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+ %vec : vector<[16]x1xi8>,
+ %mask_dim0 : index) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %pad = arith.constant 0 : i8
+ %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1>
+ vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} :
+ vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
+ return
+}
+// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing
+// CHECK-SAME: %[[ARG:.+]]: memref<?x1xi8
+// CHECK-SAME: %{{.*}}: vector<[16]x1xi8>,
+// CHECK-SAME: %[[MASK_DIM0:.+]]: index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[MASK:.+]] = vector.create_mask %[[MASK_DIM0]] : vector<[16]xi1>
+// CHECK: %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+// CHECK: vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
+
/// Only masks operands of vector.create_mask are currently supported.
func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
Outdated
Show resolved
Hide resolved
I've hit this a few times recently too, seems like something has changed 😕
Well it worked now I guess :/
This does the same as #72142 for vector.transfer_write. Previously the pattern would silently drop the mask.