Skip to content

[MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern patterns #112394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

banach-space
Copy link
Contributor

Updates TransferWriteDropUnitDimsPattern and
TransferReadDropUnitDimsPattern to inherit from
MaskableOpRewritePattern so that masked versions of
xfer_read/xfer_write Ops are also supported:

    %v = vector.mask %mask {
      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
        memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
    } : vector<3x2xi1> -> vector<3x2xi8>

@llvmbot
Copy link
Member

llvmbot commented Oct 15, 2024

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Updates TransferWriteDropUnitDimsPattern and
TransferReadDropUnitDimsPattern to inherit from
MaskableOpRewritePattern so that masked versions of
xfer_read/xfer_write Ops are also supported:

    %v = vector.mask %mask {
      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
        memref&lt;1x1x3x2xi8, strided&lt;[6, 6, 2, 1], offset: ?&gt;&gt;, vector&lt;3x2xi8&gt;
    } : vector&lt;3x2xi1&gt; -&gt; vector&lt;3x2xi8&gt;

Full diff: https://github.com/llvm/llvm-project/pull/112394.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+42-17)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir (+47-1)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index e05c801121ffc4..6cdfa645d78f9b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -354,11 +354,13 @@ namespace {
 /// inserting a memref.subview dropping those unit dims. The vector shapes are
 /// also reduced accordingly.
 class TransferReadDropUnitDimsPattern
-    : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
@@ -406,15 +408,23 @@ class TransferReadDropUnitDimsPattern
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
-    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
+    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
         loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
         transferReadOp.getPadding(), maskOp,
         rewriter.getBoolArrayAttr(inBounds));
+
+    if (maskingOp) {
+      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+          loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+          maskingOp.getMask());
+      newTransferReadOp = mlir::vector::maskOperation(
+          rewriter, newTransferReadOp, shapeCastMask);
+    }
+
     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
-        loc, vectorType, newTransferReadOp);
-    rewriter.replaceOp(transferReadOp, shapeCast);
+        loc, vectorType, newTransferReadOp->getResults()[0]);
 
-    return success();
+    return shapeCast;
   }
 };
 
@@ -422,11 +432,13 @@ class TransferReadDropUnitDimsPattern
 /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
 /// vector shapes are also reduced accordingly.
 class TransferWriteDropUnitDimsPattern
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
     auto loc = transferWriteOp.getLoc();
     Value vector = transferWriteOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
@@ -474,13 +486,26 @@ class TransferWriteDropUnitDimsPattern
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
-    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
         loc, reducedVectorType, vector);
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
-        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
+    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
+        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
+        maskOp, rewriter.getBoolArrayAttr(inBounds));
+
+    if (maskingOp) {
+      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+          loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+          maskingOp.getMask());
+      newXferWrite =
+          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
+    }
 
-    return success();
+    if (transferWriteOp.hasPureTensorSemantics())
+      return newXferWrite->getResults()[0];
+
+    // With Memref semantics, there's no return value. Use empty value to signal
+    // success.
+    return Value();
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index e9d12b044e2c7e..9b61c8ea76f962 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s --transform-interpreter | FileCheck %s
 
+//-----------------------------------------------------------------------------
+// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadeDropUnitDimsPattern]
+//-----------------------------------------------------------------------------
+
 func.func @transfer_read_rank_reducing(
       %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
     %c0 = arith.constant 0 : index
@@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing(
 //  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]]
 
-func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
+func.func @transfer_read_rank_reducing_masked(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %mask: vector<3x2xi1>) -> vector<3x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.mask %mask {
+      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+        memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
+    } : vector<3x2xi1> -> vector<3x2xi8>
+    return %v : vector<3x2xi8>
+}
+// CHECK-LABEL: func @transfer_read_rank_reducing_masked
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2xi8
+//  CHECK-SAME:     %[[MASK:.+]]: vector<3x2xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+//       CHECK:   vector.mask %[[MASK]]
+//  CHECK-SAME:  vector.transfer_read %[[SUBVIEW]]
+
+func.func @transfer_write_rank_reducing(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %vec : vector<3x2xi8>) {
+
     %c0 = arith.constant 0 : index
     vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
       vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
@@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
 //  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
 //       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]
 
+func.func @transfer_write_rank_reducing_masked(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %vec : vector<3x2xi8>,
+      %mask: vector<3x2xi1>) {
+    %c0 = arith.constant 0 : index
+    vector.mask %mask {
+      vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+        vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
+    } : vector<3x2xi1>
+    return
+}
+// CHECK-LABEL: func @transfer_write_rank_reducing_masked
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2xi8
+//  CHECK-SAME:     %[[VEC:.+]]: vector<3x2xi8>
+//  CHECK-SAME:     %[[MASK:.+]]: vector<3x2xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+//       CHECK:   vector.mask %[[MASK]]
+//  CHECK-SAME:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]
+
 func.func @transfer_read_and_vector_rank_reducing(
       %arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
     %c0 = arith.constant 0 : index

@llvmbot
Copy link
Member

llvmbot commented Oct 15, 2024

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Changes

Updates TransferWriteDropUnitDimsPattern and
TransferReadDropUnitDimsPattern to inherit from
MaskableOpRewritePattern so that masked versions of
xfer_read/xfer_write Ops are also supported:

    %v = vector.mask %mask {
      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
        memref&lt;1x1x3x2xi8, strided&lt;[6, 6, 2, 1], offset: ?&gt;&gt;, vector&lt;3x2xi8&gt;
    } : vector&lt;3x2xi1&gt; -&gt; vector&lt;3x2xi8&gt;

Full diff: https://github.com/llvm/llvm-project/pull/112394.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+42-17)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir (+47-1)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index e05c801121ffc4..6cdfa645d78f9b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -354,11 +354,13 @@ namespace {
 /// inserting a memref.subview dropping those unit dims. The vector shapes are
 /// also reduced accordingly.
 class TransferReadDropUnitDimsPattern
-    : public OpRewritePattern<vector::TransferReadOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public vector::MaskableOpRewritePattern<vector::TransferReadOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::TransferReadOp transferReadOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
     auto loc = transferReadOp.getLoc();
     Value vector = transferReadOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
@@ -406,15 +408,23 @@ class TransferReadDropUnitDimsPattern
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
-    auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
+    Operation *newTransferReadOp = rewriter.create<vector::TransferReadOp>(
         loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
         transferReadOp.getPadding(), maskOp,
         rewriter.getBoolArrayAttr(inBounds));
+
+    if (maskingOp) {
+      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+          loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+          maskingOp.getMask());
+      newTransferReadOp = mlir::vector::maskOperation(
+          rewriter, newTransferReadOp, shapeCastMask);
+    }
+
     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
-        loc, vectorType, newTransferReadOp);
-    rewriter.replaceOp(transferReadOp, shapeCast);
+        loc, vectorType, newTransferReadOp->getResults()[0]);
 
-    return success();
+    return shapeCast;
   }
 };
 
@@ -422,11 +432,13 @@ class TransferReadDropUnitDimsPattern
 /// has unit dims, by inserting a `memref.subview` dropping those unit dims. The
 /// vector shapes are also reduced accordingly.
 class TransferWriteDropUnitDimsPattern
-    : public OpRewritePattern<vector::TransferWriteOp> {
-  using OpRewritePattern::OpRewritePattern;
+    : public vector::MaskableOpRewritePattern<vector::TransferWriteOp> {
+  using MaskableOpRewritePattern::MaskableOpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
-                                PatternRewriter &rewriter) const override {
+  FailureOr<Value>
+  matchAndRewriteMaskableOp(vector::TransferWriteOp transferWriteOp,
+                            vector::MaskingOpInterface maskingOp,
+                            PatternRewriter &rewriter) const override {
     auto loc = transferWriteOp.getLoc();
     Value vector = transferWriteOp.getVector();
     VectorType vectorType = cast<VectorType>(vector.getType());
@@ -474,13 +486,26 @@ class TransferWriteDropUnitDimsPattern
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
     SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
-    auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
+    auto shapeCastSrc = rewriter.createOrFold<vector::ShapeCastOp>(
         loc, reducedVectorType, vector);
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
-        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
+    Operation *newXferWrite = rewriter.create<vector::TransferWriteOp>(
+        loc, Type(), shapeCastSrc, reducedShapeSource, zeros, identityMap,
+        maskOp, rewriter.getBoolArrayAttr(inBounds));
+
+    if (maskingOp) {
+      auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
+          loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
+          maskingOp.getMask());
+      newXferWrite =
+          mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);
+    }
 
-    return success();
+    if (transferWriteOp.hasPureTensorSemantics())
+      return newXferWrite->getResults()[0];
+
+    // With Memref semantics, there's no return value. Use empty value to signal
+    // success.
+    return Value();
   }
 };
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index e9d12b044e2c7e..9b61c8ea76f962 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s --transform-interpreter | FileCheck %s
 
+//-----------------------------------------------------------------------------
+// [Patterns: TransferWriteDropUnitDimsPattern, TransferReadeDropUnitDimsPattern]
+//-----------------------------------------------------------------------------
+
 func.func @transfer_read_rank_reducing(
       %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>) -> vector<3x2xi8> {
     %c0 = arith.constant 0 : index
@@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing(
 //  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]]
 
-func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) {
+func.func @transfer_read_rank_reducing_masked(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %mask: vector<3x2xi1>) -> vector<3x2xi8> {
+    %c0 = arith.constant 0 : index
+    %cst = arith.constant 0 : i8
+    %v = vector.mask %mask {
+      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
+        memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
+    } : vector<3x2xi1> -> vector<3x2xi8>
+    return %v : vector<3x2xi8>
+}
+// CHECK-LABEL: func @transfer_read_rank_reducing_masked
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2xi8
+//  CHECK-SAME:     %[[MASK:.+]]: vector<3x2xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+//       CHECK:   vector.mask %[[MASK]]
+//  CHECK-SAME:  vector.transfer_read %[[SUBVIEW]]
+
+func.func @transfer_write_rank_reducing(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %vec : vector<3x2xi8>) {
+
     %c0 = arith.constant 0 : index
     vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
       vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
@@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
 //  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
 //       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]
 
+func.func @transfer_write_rank_reducing_masked(
+      %arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>,
+      %vec : vector<3x2xi8>,
+      %mask: vector<3x2xi1>) {
+    %c0 = arith.constant 0 : index
+    vector.mask %mask {
+      vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] :
+        vector<3x2xi8>, memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>
+    } : vector<3x2xi1>
+    return
+}
+// CHECK-LABEL: func @transfer_write_rank_reducing_masked
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x2xi8
+//  CHECK-SAME:     %[[VEC:.+]]: vector<3x2xi8>
+//  CHECK-SAME:     %[[MASK:.+]]: vector<3x2xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
+//       CHECK:   vector.mask %[[MASK]]
+//  CHECK-SAME:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]
+
 func.func @transfer_read_and_vector_rank_reducing(
       %arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> {
     %c0 = arith.constant 0 : index

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was looking for some tests with shape_cast ops generated; I found that it does not support the below case. Can you take a look?

Note: it is a test in the original file and I just added a vector.mask op around it.

func.func @transfer_read_and_vector_rank_reducing_to_0d_masked(
      %arg : memref<1x1x1x1x1xf32>,
      %mask: vector<1x1x1xi1>) -> vector<1x1x1xf32> {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.0 : f32
    %v = vector.mask %mask {
      vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
      memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
    } : vector<1x1x1xi1> -> vector<1x1x1xf32>
    return %v : vector<1x1x1xf32>
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.rank_reducing_subview_patterns
    } : !transform.op<"func.func">
    transform.yield
  }
}


if (maskingOp) {
auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think using std::nullopt is clearer instead of {}.

@hanhanW
Copy link
Contributor

hanhanW commented Oct 17, 2024

The error log:

zz.mlir:19:7: error: 'vector.mask' op operand #0 must be vector of 1-bit signless integer values, but got 'vector<i1>'

Updates `TransferWriteDropUnitDimsPattern` and
`TransferReadDropUnitDimsPattern` to inherit from
`MaskableOpRewritePattern` so that masked versions of
xfer_read/xfer_write Ops are also supported:

```mlir
    %v = vector.mask %mask {
      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
        memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
    } : vector<3x2xi1> -> vector<3x2xi8>
```
@banach-space banach-space force-pushed the andrzej/update_drop_unit_dim_patterns branch from c6a91f5 to 0de9b8c Compare October 18, 2024 17:41
@banach-space
Copy link
Contributor Author

I was looking for some tests with shape_cast ops generated; I found that it does not support the below case. Can you take a look?

Note: it is a test in the original file and I just added a vector.mask op around it.

func.func @transfer_read_and_vector_rank_reducing_to_0d_masked(
      %arg : memref<1x1x1x1x1xf32>,
      %mask: vector<1x1x1xi1>) -> vector<1x1x1xf32> {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.0 : f32
    %v = vector.mask %mask {
      vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst :
      memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
    } : vector<1x1x1xi1> -> vector<1x1x1xf32>
    return %v : vector<1x1x1xf32>
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
    %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">
    transform.apply_patterns to %func_op {
      transform.apply_patterns.vector.rank_reducing_subview_patterns
    } : !transform.op<"func.func">
    transform.yield
  }
}

Great catch, thanks for checking! Turns out vector.mask doesn't support 0-d vectors. This should be easy to extend, but let's bail out for now - I am a bit blocked 😅

Copy link
Contributor

@nujaa nujaa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. I must have missed this one in my pattern maskification spree. Thanks.

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Pretty good to see how MaskableOpRewritePattern::MaskableOpRewritePattern allows us to add support for masked ops with minimal changes, reusing the same patterns!

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch, thanks for checking! Turns out vector.mask doesn't support 0-d vectors. This should be easy to extend, but let's bail out for now - I am a bit blocked 😅

SGTM, thanks for taking a look!

@banach-space banach-space merged commit 0cf7aaf into llvm:main Oct 26, 2024
8 checks passed
winner245 pushed a commit to winner245/llvm-project that referenced this pull request Oct 26, 2024
llvm#112394)

Updates `TransferWriteDropUnitDimsPattern` and
`TransferReadDropUnitDimsPattern` to inherit from
`MaskableOpRewritePattern` so that masked versions of
xfer_read/xfer_write Ops are also supported:

```mlir
    %v = vector.mask %mask {
      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
        memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
    } : vector<3x2xi1> -> vector<3x2xi8>
```
NoumanAmir657 pushed a commit to NoumanAmir657/llvm-project that referenced this pull request Nov 4, 2024
llvm#112394)

Updates `TransferWriteDropUnitDimsPattern` and
`TransferReadDropUnitDimsPattern` to inherit from
`MaskableOpRewritePattern` so that masked versions of
xfer_read/xfer_write Ops are also supported:

```mlir
    %v = vector.mask %mask {
      vector.transfer_read %arg[%c0, %c0, %c0, %c0], %cst :
        memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, vector<3x2xi8>
    } : vector<3x2xi1> -> vector<3x2xi8>
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants