-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern patterns #112394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern patterns #112394
Conversation
@llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) Changes: Updates `TransferWriteDropUnitDimsPattern` and `TransferReadDropUnitDimsPattern` to inherit from `MaskableOpRewritePattern` so that masked versions of xfer_read/xfer_write ops are also supported: %v = vector.mask %mask {
vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
} : vector<3x2xi1> -> vector<3x2xi8> Full diff: https://github.com/llvm/llvm-project/pull/112394.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index e05c801121ffc4..6cdfa645d78f9b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -354,11 +354,13 @@ namespace {
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
class TransferReadDropUnitDimsPattern
- : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
- PatternRewriter &rewriter) const override {
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
auto loc = transferReadOp.getLoc();
Value vector = transferReadOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
@@ -406,15 +408,23 @@ class TransferReadDropUnitDimsPattern
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
- auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
+ Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
transferReadOp.getPadding(), maskOp,
rewriter.getBoolArrayAttr(inBounds));
+
+ if (maskingOp) {
+ auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+ maskingOp.getMask());
+ newTransferReadOp = mlir::vector::maskOperation(
+ rewriter, newTransferReadOp, shapeCastMask);
+ }
+
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
- loc, vectorType, newTransferReadOp);
- rewriter.replaceOp(transferReadOp, shapeCast);
+ loc, vectorType, newTransferReadOp->getResults()[0]);
- return success();
+ return shapeCast;
}
};
@@ -422,11 +432,13 @@ class TransferReadDropUnitDimsPattern
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
class TransferWriteDropUnitDimsPattern
- : public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
- PatternRewriter &rewriter) const override {
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
auto loc = transferWriteOp.getLoc();
Value vector = transferWriteOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
@@ -474,13 +486,26 @@ class TransferWriteDropUnitDimsPattern
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
- auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+ auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType, vector);
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
- identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
+ Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
+ loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
+ maskOp, rewriter.getBoolArrayAttr(inBounds));
+
+ if (maskingOp) {
+ auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+ maskingOp.getMask());
+ newXferWrite =
+ mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
+ }
- return success();
+ if (transferWriteOp.hasPureTensorSemantics())
+ return newXferWrite->getResults()[0];
+
+ // With Memref semantics, there's no return value. Use empty value to signal
+ // success.
+ return Value();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index e9d12b044e2c7e..9b61c8ea76f962 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -1,5 +1,9 @@
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+//-----------------------------------------------------------------------------
+// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadDropUnitDimsPattern]
+//-----------------------------------------------------------------------------
+
func.func @transfer_read_rank_reducing(
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
%c0 = arith.constant 0 : index
@@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing(
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]]
-func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
+func.func @transfer_read_rank_reducing_masked(
+ %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+ %mask: vector<3x2xi1>) -> vector<3x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.mask %mask {
+ vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
+ } : vector<3x2xi1> -> vector<3x2xi8>
+ return %v : vector<3x2xi8>
+}
+// CHECK-LABEL: func @transfer_read_rank_reducing_masked
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
+// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+// CHECK: vector.mask %[[MASK]]
+// CHECK-SAME: vector.transfer_read %[[SUBVIEW]]
+
+func.func @transfer_write_rank_reducing(
+ %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+ %vec : vector<3x2xi8>) {
+
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
@@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
+func.func @transfer_write_rank_reducing_masked(
+ %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+ %vec : vector<3x2xi8>,
+ %mask: vector<3x2xi1>) {
+ %c0 = arith.constant 0 : index
+ vector.mask %mask {
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
+ } : vector<3x2xi1>
+ return
+}
+// CHECK-LABEL: func @transfer_write_rank_reducing_masked
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
+// CHECK-SAME: %[[VEC:.+]]: vector<3x2xi8>
+// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+// CHECK: vector.mask %[[MASK]]
+// CHECK-SAME: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
+
func.func @transfer_read_and_vector_rank_reducing(
%arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
%c0 = arith.constant 0 : index
|
@llvm/pr-subscribers-mlir-vector Author: Andrzej Warzyński (banach-space) Changes: Updates `TransferWriteDropUnitDimsPattern` and `TransferReadDropUnitDimsPattern` to inherit from `MaskableOpRewritePattern` so that masked versions of xfer_read/xfer_write ops are also supported: %v = vector.mask %mask {
vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
} : vector<3x2xi1> -> vector<3x2xi8> Full diff: https://github.com/llvm/llvm-project/pull/112394.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index e05c801121ffc4..6cdfa645d78f9b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -354,11 +354,13 @@ namespace {
/// inserting a memref.subview dropping those unit dims. The vector shapes are
/// also reduced accordingly.
class TransferReadDropUnitDimsPattern
- : public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
- PatternRewriter &rewriter) const override {
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
auto loc = transferReadOp.getLoc();
Value vector = transferReadOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
@@ -406,15 +408,23 @@ class TransferReadDropUnitDimsPattern
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
- auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
+ Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
transferReadOp.getPadding(), maskOp,
rewriter.getBoolArrayAttr(inBounds));
+
+ if (maskingOp) {
+ auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+ maskingOp.getMask());
+ newTransferReadOp = mlir::vector::maskOperation(
+ rewriter, newTransferReadOp, shapeCastMask);
+ }
+
auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
- loc, vectorType, newTransferReadOp);
- rewriter.replaceOp(transferReadOp, shapeCast);
+ loc, vectorType, newTransferReadOp->getResults()[0]);
- return success();
+ return shapeCast;
}
};
@@ -422,11 +432,13 @@ class TransferReadDropUnitDimsPattern
/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
/// vector shapes are also reduced accordingly.
class TransferWriteDropUnitDimsPattern
- : public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
+ using MaskableOpRewritePattern::MaskableOpRewritePattern;
- LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
- PatternRewriter &rewriter) const override {
+ FailureOr<Value>
+ matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
+ vector::MaskingOpInterface maskingOp,
+ PatternRewriter &rewriter) const override {
auto loc = transferWriteOp.getLoc();
Value vector = transferWriteOp.getVector();
VectorType vectorType = cast<VectorType>(vector.getType());
@@ -474,13 +486,26 @@ class TransferWriteDropUnitDimsPattern
SmallVector<Value> zeros(reducedRank, c0);
auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
- auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+ auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType, vector);
- rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
- transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
- identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
+ Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
+ loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
+ maskOp, rewriter.getBoolArrayAttr(inBounds));
+
+ if (maskingOp) {
+ auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+ maskingOp.getMask());
+ newXferWrite =
+ mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
+ }
- return success();
+ if (transferWriteOp.hasPureTensorSemantics())
+ return newXferWrite->getResults()[0];
+
+ // With Memref semantics, there's no return value. Use empty value to signal
+ // success.
+ return Value();
}
};
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index e9d12b044e2c7e..9b61c8ea76f962 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -1,5 +1,9 @@
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+//-----------------------------------------------------------------------------
+// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadDropUnitDimsPattern]
+//-----------------------------------------------------------------------------
+
func.func @transfer_read_rank_reducing(
%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
%c0 = arith.constant 0 : index
@@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing(
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_read %[[SUBVIEW]]
-func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
+func.func @transfer_read_rank_reducing_masked(
+ %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+ %mask: vector<3x2xi1>) -> vector<3x2xi8> {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0 : i8
+ %v = vector.mask %mask {
+ vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+ memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
+ } : vector<3x2xi1> -> vector<3x2xi8>
+ return %v : vector<3x2xi8>
+}
+// CHECK-LABEL: func @transfer_read_rank_reducing_masked
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
+// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+// CHECK: vector.mask %[[MASK]]
+// CHECK-SAME: vector.transfer_read %[[SUBVIEW]]
+
+func.func @transfer_write_rank_reducing(
+ %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+ %vec : vector<3x2xi8>) {
+
%c0 = arith.constant 0 : index
vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
@@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
+func.func @transfer_write_rank_reducing_masked(
+ %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+ %vec : vector<3x2xi8>,
+ %mask: vector<3x2xi1>) {
+ %c0 = arith.constant 0 : index
+ vector.mask %mask {
+ vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+ vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
+ } : vector<3x2xi1>
+ return
+}
+// CHECK-LABEL: func @transfer_write_rank_reducing_masked
+// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
+// CHECK-SAME: %[[VEC:.+]]: vector<3x2xi8>
+// CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
+// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+// CHECK: vector.mask %[[MASK]]
+// CHECK-SAME: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
+
func.func @transfer_read_and_vector_rank_reducing(
%arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
%c0 = arith.constant 0 : index
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was looking for some tests with shape_cast ops generated; I found that it does not support the below case. Can you take a look?
Note: it is a test in the original file and I just added a vector.mask
op around it.
func.func @transfer_read_and_vector_rank_reducing_to_0d_masked(
%arg : memref<1x1x1x1x1xf32>,
%mask: vector<1x1x1xi1>) -> vector<1x1x1xf32> {
%c0 = arith.constant 0 : index
%cst = arith.constant 0.0 : f32
%v = vector.mask %mask {
vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
} : vector<1x1x1xi1> -> vector<1x1x1xf32>
return %v : vector<1x1x1xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
%func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
transform.apply_patterns to %func_op {
transform.apply_patterns.vector.rank_reducing_subview_patterns
} : !transform.op<"func.func">
transform.yield
}
}
|
if (maskingOp) {
  auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
      loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think using `std::nullopt` is clearer instead of `{}`.
The error log:
|
Updates `TransferWriteDropUnitDimsPattern` and `TransferReadDropUnitDimsPattern` to inherit from `MaskableOpRewritePattern` so that masked versions of xfer_read/xfer_write Ops are also supported: ```mlir %v = vector.mask %mask { vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8> } : vector<3x2xi1> -> vector<3x2xi8> ```
c6a91f5
to
0de9b8c
Compare
…patterns Bail out for 0D cases
Great catch, thanks for checking! Turns out `vector.mask` doesn't support 0-d vectors. This should be easy to extend, but let's bail out for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. I must have missed this one in my pattern maskification spree. Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Pretty good to see how MaskableOpRewritePattern::MaskableOpRewritePattern
allows us to add support for masked ops with minimal changes, reusing the same patterns!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great catch, thanks for checking! Turns out vector.mask doesn't support 0-d vectors. This should be easy to extend, but let's bail out for now - I am a bit blocked 😅
SGTM, thanks for taking a look!
llvm#112394) Updates `TransferWriteDropUnitDimsPattern` and `TransferReadDropUnitDimsPattern` to inherit from `MaskableOpRewritePattern` so that masked versions of xfer_read/xfer_write Ops are also supported: ```mlir %v = vector.mask %mask { vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8> } : vector<3x2xi1> -> vector<3x2xi8> ```
llvm#112394) Updates `TransferWriteDropUnitDimsPattern` and `TransferReadDropUnitDimsPattern` to inherit from `MaskableOpRewritePattern` so that masked versions of xfer_read/xfer_write Ops are also supported: ```mlir %v = vector.mask %mask { vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8> } : vector<3x2xi1> -> vector<3x2xi8> ```
Updates
TransferWriteDropUnitDimsPattern
andTransferReadDropUnitDimsPattern
to inherit fromMaskableOpRewritePattern
so that masked versions ofxfer_read/xfer_write Ops are also supported: