[mlir][vector] Determine vector sizes from the result shape in the case of tensor pack #88249


Merged: 1 commit, Apr 17, 2024

Conversation

pashu123 (Member)

When vector sizes are not passed to the vectorize transform operation, they are inferred from the static result shape in the case of the tensor.pack op.
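For illustration, here is a minimal, hypothetical sketch of the new behavior (not taken from this PR; the function name and shapes are made up, and the IR is modeled on the test added below). Because the destination type tensor<4x16x16x1xf32> is fully static, transform.structured.vectorize can be applied without an explicit vector_sizes list:

// Hypothetical example: static dest shape, no vector sizes passed.
func.func @pack_no_vector_sizes(%src: tensor<64x16xf32>, %dst: tensor<4x16x16x1xf32>) -> tensor<4x16x16x1xf32> {
  %0 = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %dst
      : tensor<64x16xf32> -> tensor<4x16x16x1xf32>
  return %0 : tensor<4x16x16x1xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %pack = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    // No vector_sizes operand: sizes are inferred from the static dest shape.
    transform.structured.vectorize %pack : !transform.any_op
    transform.yield
  }
}

In this sketch the inferred sizes are the leading source-rank dimensions of the destination shape, [4, 16], which vectorizeAsTensorPackOp then scales by the inner tiles (16 and 1) to build the 64x16 read shape matching the source tensor.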

llvmbot (Member) commented Apr 10, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Prashant Kumar (pashu123)

Changes

When vector sizes are not passed to the vectorize transform operation, they are inferred from the static result shape in the case of the tensor.pack op.


Full diff: https://github.com/llvm/llvm-project/pull/88249.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+16-4)
  • (modified) mlir/test/Dialect/Linalg/vectorization.mlir (+34)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 25785653a71675..422fc0562f9003 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1525,6 +1525,17 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
   (void)status; // prevent unused variable warning on non-assert builds.
   assert(succeeded(status) && "failed to reify result shapes");
 
+  ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+
+  // If the input vector sizes are not provided, then the vector sizes are
+  // determined by the result tensor shape.
+  if (inputVectorSizes.empty()) {
+    // Make sure that the result tensor shape is static.
+    if (ShapedType::isDynamicShape(resultTensorShape))
+      return failure();
+    inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
+  }
+
   // Create masked TransferReadOp.
   SmallVector<int64_t> inputShape(inputVectorSizes);
   auto innerTiles = packOp.getStaticInnerTiles();
@@ -1763,7 +1774,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
 /// Returns success if `inputVectorSizes` is a valid masking configuraion for
 /// given `shape`, i.e., it meets:
 ///   1. The numbers of elements in both array are equal.
-///   2. `inputVectorSizes` does nos have dynamic dimensions.
+///   2. `inputVectorSizes` does not have dynamic dimensions.
 ///   3. All the values in `inputVectorSizes` are greater than or equal to
 ///      static sizes in `shape`.
 static LogicalResult
@@ -1881,18 +1892,19 @@ static LogicalResult vectorizeLinalgOpPrecondition(
   return success();
 }
 
-/// TODO: Use a matcher to check for a constant padding value.
 static LogicalResult
 vectorizePackOpPrecondition(tensor::PackOp packOp,
                             ArrayRef<int64_t> inputVectorSizes) {
   auto padValue = packOp.getPaddingValue();
-  if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
+  Attribute cstAttr;
+  if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
     LDBG("pad value is not constant: " << packOp << "\n");
     return failure();
   }
 
   ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
-  if (failed(isValidMaskedInputVector(
+  if (!inputVectorSizes.empty() &&
+      failed(isValidMaskedInputVector(
           resultTensorShape.take_front(packOp.getSourceRank()),
           inputVectorSizes)))
     return failure();
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 2d01d57304013c..f354ab9ea0b0a3 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -812,3 +812,37 @@ func.func @test_vectorize_unpack_no_masks(%source: tensor<8x8x32x16xf32>, %dest:
     transform.yield
   } 
 }
+
+  // -----
+
+// CHECK-LABEL: test_vectorize_padded_pack_no_vector_sizes
+func.func @test_vectorize_padded_pack_no_vector_sizes(%arg0: tensor<32x7x15xf32>, %arg1: tensor<32x4x1x16x2xf32>) -> tensor<32x4x1x16x2xf32> {
+  %pad = arith.constant 0.000000e+00 : f32
+  %pack = tensor.pack %arg0 padding_value(%pad : f32) inner_dims_pos = [2, 1] inner_tiles = [16, 2] into %arg1 : tensor<32x7x15xf32> -> tensor<32x4x1x16x2xf32>
+  return %pack : tensor<32x4x1x16x2xf32>
+}
+//  CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+//  CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index
+//  CHECK-DAG: %[[c7:.*]] = arith.constant 7 : index
+//  CHECK-DAG: %[[c15:.*]] = arith.constant 15 : index
+//      CHECK: %[[mask:.*]] = vector.create_mask %[[c32]], %[[c7]], %[[c15]] : vector<32x8x16xi1>
+//      CHECK: %[[masked_read:.*]] = vector.mask %[[mask]] {
+// CHECK-SAME:   vector.transfer_read %{{.*}}[%[[c0]], %[[c0]], %[[c0]]], %[[cst]]
+// CHECK-SAME:   {in_bounds = [true, true, true]} : tensor<32x7x15xf32>, vector<32x8x16xf32>
+// CHECK-SAME: } : vector<32x8x16xi1> -> vector<32x8x16xf32>
+//      CHECK: %[[shape_cast:.*]] = vector.shape_cast %[[masked_read]] : vector<32x8x16xf32> to vector<32x4x2x1x16xf32>
+//      CHECK: %[[transpose:.*]] = vector.transpose %[[shape_cast]], [0, 1, 3, 4, 2] : vector<32x4x2x1x16xf32> to vector<32x4x1x16x2xf32>
+//  CHECK-DAG: %[[c0_1:.*]] = arith.constant 0 : index
+//  CHECK-DAG: %[[empty:.*]] = tensor.empty() : tensor<32x4x1x16x2xf32>
+//      CHECK: %[[write:.*]] = vector.transfer_write %[[transpose]], %[[empty]][%[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]], %[[c0_1]]]
+// CHECK-SAME:   {in_bounds = [true, true, true, true, true]} : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+//      CHECK: return %[[write]] : tensor<32x4x1x16x2xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pack"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield 
+  }
+}
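A note on the expected mask shape in this test: with no vector sizes provided, the inferred sizes are the leading source-rank dimensions of the destination shape tensor<32x4x1x16x2xf32>, i.e. [32, 4, 1]. vectorizeAsTensorPackOp then multiplies each size by its inner tile (inner_dims_pos = [2, 1], inner_tiles = [16, 2]), yielding a read shape of 32 x (4*2) x (1*16) = 32x8x16. That is why the transfer_read mask is vector<32x8x16xi1> even though the source is tensor<32x7x15xf32>: dimensions 7 and 15 are padded up to 8 and 16 with the constant pad value. Separately, the precondition now uses matchPattern with m_Constant instead of requiring an arith.constant defining op, so a padding value produced by any constant-like op should be accepted.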

hanhanW requested review from Max191 and dcaballe, April 10, 2024 16:35
pashu123 requested a review from hanhanW, April 12, 2024 12:21
pashu123 force-pushed the stat_pack branch 2 times, most recently from 4a5dbc3 to 05c23ec, April 15, 2024 08:59
pashu123 force-pushed the stat_pack branch 2 times, most recently from 2932ae2 to 7bfa06d, April 16, 2024 09:40
pashu123 requested a review from hanhanW, April 16, 2024 09:42
hanhanW (Contributor) left a comment


LGTM, just two nits. Thanks for pushing on this!

hanhanW requested a review from chelini, April 16, 2024 16:47
[mlir][vector] Determine vector sizes from the result shape in the case of tensor pack

When vector sizes are not passed to the vectorize transform operation, they are inferred from the static result shape in the case of the tensor.pack op.