[mlir][vector] Fix unit dim dropping pattern for masked writes #74038


Merged: 3 commits merged into llvm:main on Dec 1, 2023

Conversation

qedawkins (Contributor)

This does the same as #72142 for vector.transfer_write. Previously the pattern would silently drop the mask.

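For illustration, a minimal before/after sketch adapted from the new test added in this PR (the `%c0`, `%c1`, `%dim1`, and `%dim2` SSA names are placeholders): previously the rewrite emitted the rank-reduced `vector.transfer_write` without the mask, whereas with this change the `vector.create_mask` is rank-reduced alongside the vector and the destination memref.

```mlir
// Before: masked write with unit dims in the vector, the mask, and the destination.
%mask = vector.create_mask %c1, %dim1, %c1, %dim2, %c1 : vector<1x3x1x16x1xi1>
vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask
    : vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>

// After: the unit dims are dropped from the mask as well, instead of the mask
// being silently discarded.
%mask2 = vector.create_mask %dim1, %dim2 : vector<3x16xi1>
%subview = memref.subview %arg[0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
    : memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
%cast = vector.shape_cast %vec : vector<1x3x1x16x1xf32> to vector<3x16xf32>
vector.transfer_write %cast, %subview[%c0, %c0], %mask2 {in_bounds = [true, true]}
    : vector<3x16xf32>, memref<3x16xf32>
```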
@llvmbot (Member) commented Dec 1, 2023

@llvm/pr-subscribers-mlir

Author: Quinn Dawkins (qedawkins)

Changes

This does the same as #72142 for vector.transfer_write. Previously the pattern would silently drop the mask.


Full diff: https://github.com/llvm/llvm-project/pull/74038.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+21-17)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir (+44)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index d2c6ba557b9bbec..0dc097158a4a55d 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -260,14 +260,6 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
   opToErase.push_back(read.getOperation());
 }
 
-/// Returns a copy of `shape` without unit dims.
-static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
-  SmallVector<int64_t> reducedShape;
-  llvm::copy_if(shape, std::back_inserter(reducedShape),
-                [](int64_t dimSize) { return dimSize != 1; });
-  return reducedShape;
-}
-
 /// Converts OpFoldResults to int64_t shape without unit dims.
 static SmallVector<int64_t> getReducedShape(ArrayRef<OpFoldResult> mixedSizes) {
   SmallVector<int64_t> reducedShape;
@@ -446,9 +438,7 @@ class TransferWriteDropUnitDimsPattern
     Value source = transferWriteOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
     // TODO: support tensor type.
-    if (!sourceType || !sourceType.hasStaticShape())
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    if (!sourceType)
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferWriteOp.hasOutOfBoundsDim())
@@ -461,25 +451,39 @@ class TransferWriteDropUnitDimsPattern
       return failure();
     // Check if the reduced vector shape matches the reduced destination shape.
     // Otherwise, this case is not supported yet.
-    int vectorReducedRank = getReducedRank(vectorType.getShape());
-    if (reducedRank != vectorReducedRank)
+    auto reducedVectorType = trimNonScalableUnitDims(vectorType);
+    if (reducedRank != reducedVectorType.getRank())
       return failure();
     if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) {
           return getConstantIntValue(v) != static_cast<int64_t>(0);
         }))
       return failure();
+
+    Value maskOp = transferWriteOp.getMask();
+    if (maskOp) {
+      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return rewriter.notifyMatchFailure(
+            transferWriteOp,
+            "unsupported mask op, only 'vector.create_mask' is "
+            "currently supported");
+      FailureOr<Value> rankReducedCreateMask =
+          createMaskDropNonScalableUnitDims(rewriter, loc, createMaskOp);
+      if (failed(rankReducedCreateMask))
+        return failure();
+      maskOp = *rankReducedCreateMask;
+    }
     Value reducedShapeSource =
         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
-    VectorType reducedVectorType = VectorType::get(
-        getReducedShape(vectorType.getShape()), vectorType.getElementType());
-
+    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
         loc, reducedVectorType, vector);
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap);
+        transferWriteOp, Type(), shapeCast, reducedShapeSource, zeros,
+        identityMap, maskOp, rewriter.getBoolArrayAttr(inBounds));
 
     return success();
   }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 735915d43565389..d65708068862f46 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -144,6 +144,50 @@ func.func @masked_transfer_read_dynamic_rank_reducing_2(
 //       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, %[[DIM1]], 3, 1, %[[DIM4]], 1] [1, 1, 1, 1, 1, 1] : memref<1x?x3x1x?x1xi8, {{.*}}> to memref<?x3x?xi8, {{.*}}>
 //       CHECK:   vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true, true, true]} : memref<?x3x?xi8, {{.*}}>, vector<[1]x3x[16]xi8>
 
+func.func @masked_transfer_write_and_vector_rank_reducing(
+      %arg : memref<1x1x3x1x16x1xf32>,
+      %vec : vector<1x3x1x16x1xf32>,
+      %mask_dim1 : index,
+      %mask_dim2 : index) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %mask = vector.create_mask %c1, %mask_dim1, %c1, %mask_dim2, %c1 : vector<1x3x1x16x1xi1>
+    vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0, %c0], %mask :
+      vector<1x3x1x16x1xf32>, memref<1x1x3x1x16x1xf32>
+    return
+}
+// CHECK-LABEL: func @masked_transfer_write_and_vector_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<1x1x3x1x16x1xf32>
+//  CHECK-SAME:     {{.*}}: vector<1x3x1x16x1xf32>,
+//  CHECK-SAME:     %[[MASKDIM1:.+]]: index,
+//  CHECK-SAME:     %[[MASKDIM2:.+]]: index
+//       CHECK:   %[[MASK:.+]] = vector.create_mask %[[MASKDIM1]], %[[MASKDIM2]] : vector<3x16xi1>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0, 0] [1, 1, 3, 1, 16, 1] [1, 1, 1, 1, 1, 1]
+//  CHECK-SAME:     memref<1x1x3x1x16x1xf32> to memref<3x16xf32>
+//       CHECK:   vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}}, %[[MASK]] {in_bounds = [true, true]} : vector<3x16xf32>, memref<3x16xf32>
+
+func.func @masked_transfer_write_dynamic_rank_reducing(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+      %vec : vector<[16]x1xi8>,
+      %mask_dim0 : index) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %pad = arith.constant 0 : i8
+    %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1>
+    vector.transfer_write %vec, %arg[%c0, %c0], %mask {in_bounds = [true, true]} :
+      vector<[16]x1xi8>, memref<?x1xi8, strided<[?, ?], offset: ?>>
+    return
+}
+// CHECK-LABEL: func @masked_transfer_write_dynamic_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
+//  CHECK-SAME:     %{{.*}}: vector<[16]x1xi8>,
+//  CHECK-SAME:     %[[MASK_DIM0:.+]]: index
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[MASK:.+]] = vector.create_mask %[[MASK_DIM0]] : vector<[16]xi1>
+//       CHECK:   %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+//       CHECK:   vector.transfer_write {{.*}}, %[[SUBVIEW]][%[[C0]]], %[[MASK]] {in_bounds = [true]} : vector<[16]xi8>, memref<?xi8, {{.*}}>
+
 /// Only masks operands of vector.create_mask are currently supported.
 func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
       %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,

@c-rhodes (Collaborator) left a comment:

LGTM cheers

@dcaballe (Contributor) left a comment:

Thanks!

@MacDue (Member) commented Dec 1, 2023

> Buildbot can't find my commit?

I've hit this a few times recently too, seems like something has changed 😕

@qedawkins (Contributor, Author)

> Buildbot can't find my commit?
>
> I've hit this a few times recently too, seems like something has changed 😕

Well it worked now I guess :/

qedawkins merged commit fdf84cb into llvm:main on Dec 1, 2023