[mlir][vector] Extend TransferReadDropUnitDimsPattern to support partially-static memrefs #72142

Merged

merged 2 commits into llvm:main from mlir-drop-unit-dims-dynamic-scalable on Nov 20, 2023

Conversation

c-rhodes
Collaborator

@c-rhodes c-rhodes commented Nov 13, 2023

This patch extends TransferReadDropUnitDimsPattern to support dropping
unit dims from partially-static memrefs, for example:

%v = vector.transfer_read %base[%c0, %c0], %pad {in_bounds = [true, true]} :
  memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>

Is rewritten as:

%dim0 = memref.dim %base, %c0 : memref<?x1xi8, strided<[?, ?], offset: ?>>
%subview = memref.subview %base[0, 0] [%dim0, 1] [1, 1] :
  memref<?x1xi8, strided<[?, ?], offset: ?>> to memref<?xi8, #map1>
%v = vector.transfer_read %subview[%c0], %pad {in_bounds = [true]}
  : memref<?xi8, #map1>, vector<[16]xi8>

Scalable vectors are now also supported; previously, the scalable dims were
dropped when creating the rank-reduced vector type. The xfer op can also
have a mask of type 'vector.create_mask', which gets rewritten as long
as the mask for the unit dim is a constant of 1.

This patch extends the vector.transfer_read drop unit dim pattern to
support scalable vectors with (non-scalable) unit dims, and dynamic
memrefs. The xfer op can also have a mask of type 'vector.create_mask',
which gets rewritten as long as the mask of the unit dim is a constant
of 1.
@c-rhodes
Collaborator Author

For context, this and #72105 enable the lowering of a regular linalg.matmul to ArmSME (#72144). The integration test in #72144 can't currently be lowered because of the following sequence (taken from the inner loop):

%subview = memref.subview %arg0[%arg3, %arg5] [%2, 1] [1, 1] : memref<?x?xf32, strided<[?, ?], offset: ?>> to memref<?x1xf32, strided<[?, ?], offset: ?>>
%mask = vector.create_mask %2, %c1 : vector<[4]x1xi1>
%0 = vector.transfer_read %subview[%c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x1xf32, strided<[?, ?], offset: ?>>, vector<[4]x1xf32>
%1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
%2 = vector.extract %1[0] : vector<[4]xf32> from vector<1x[4]xf32>

The vector.transfer_read is of a rank-2 vector type with a leading scalable dim. This can't be lowered in the generic path: although arrays in LLVM can now contain scalable vectors, they can only contain a fixed number of them, so type conversion cannot happen.

With this change the unit-dim can be dropped such that the transfer_read is of a rank-1 vector<[4]xf32>.
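To make the effect concrete, here is a hand-written sketch of what the inner loop could look like once the pattern fires (the result names and strided layouts are illustrative, not actual pattern output):

```mlir
// The rank-2 read over memref<?x1xf32> becomes a rank-1 read over
// memref<?xf32>; the transpose/extract pair that followed can then fold away.
%rank_reduced = memref.subview %subview[0, 0] [%2, 1] [1, 1]
    : memref<?x1xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
%mask1d = vector.create_mask %2 : vector<[4]xi1>
%read1d = vector.transfer_read %rank_reduced[%c0], %pad, %mask1d {in_bounds = [true]}
    : memref<?xf32, strided<[?], offset: ?>>, vector<[4]xf32>
```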

I've posted this as a draft as I'm not entirely sure this transformation is semantically correct. The documentation for the memref dialect's subview op [1] mentions:

A subview operation may additionally reduce the rank of the resulting view by removing dimensions that are statically known to be of size 1.

but the vector.transfer_read is in-bounds.

Would appreciate any thoughts / feedback.

cc @dcaballe @nicolasvasilache @banach-space @MacDue

[1] https://mlir.llvm.org/docs/Dialects/MemRef/#memrefsubview-memrefsubviewop

@c-rhodes c-rhodes changed the title [mlir][vector] Extend vector.transfer_read drop unit dim pattern [mlir][vector] Extend TransferReadDropUnitDimsPattern to support partially-static memrefs Nov 14, 2023
@c-rhodes c-rhodes marked this pull request as ready for review November 14, 2023 09:44
@c-rhodes c-rhodes requested a review from dcaballe November 14, 2023 09:44
@llvmbot
Member

llvmbot commented Nov 14, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Cullen Rhodes (c-rhodes)

Changes

This patch extends TransferReadDropUnitDimsPattern to support dropping
unit dims from partially-static memrefs, for example:

%v = vector.transfer_read %base[%c0, %c0], %pad {in_bounds = [true, true]} :
  memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>

Is rewritten as:

%dim0 = memref.dim %base, %c0 : memref<?x1xi8, strided<[?, ?], offset: ?>>
%subview = memref.subview %base[0, 0] [%dim0, 1] [1, 1] :
  memref<?x1xi8, strided<[?, ?], offset: ?>> to memref<?xi8, #map1>
%v = vector.transfer_read %subview[%c0], %pad {in_bounds = [true]}
  : memref<?xi8, #map1>, vector<[16]xi8>

Scalable vectors are now also supported; previously, the scalable dims were
dropped when creating the rank-reduced vector type. The xfer op can also
have a mask of type 'vector.create_mask', which gets rewritten as long
as the mask for the unit dim is a constant of 1.


Full diff: https://github.com/llvm/llvm-project/pull/72142.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp (+69-29)
  • (modified) mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir (+86)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index a5f1b28152b9bde..95445f2081ec89c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -260,14 +260,22 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
   opToErase.push_back(read.getOperation());
 }
 
+/// Returns a copy of `shape` without unit dims.
+static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
+  SmallVector<int64_t> reducedShape;
+  llvm::copy_if(shape, std::back_inserter(reducedShape),
+                [](int64_t dimSize) { return dimSize != 1; });
+  return reducedShape;
+}
+
 /// Drops unit dimensions from the input MemRefType.
-static MemRefType dropUnitDims(MemRefType inputType, ArrayRef<int64_t> offsets,
-                               ArrayRef<int64_t> sizes,
-                               ArrayRef<int64_t> strides) {
-  SmallVector<int64_t> targetShape = llvm::to_vector(
-      llvm::make_filter_range(sizes, [](int64_t sz) { return sz != 1; }));
+static MemRefType dropUnitDims(MemRefType inputType,
+                               ArrayRef<OpFoldResult> offsets,
+                               ArrayRef<OpFoldResult> sizes,
+                               ArrayRef<OpFoldResult> strides) {
   Type rankReducedType = memref::SubViewOp::inferRankReducedResultType(
-      targetShape, inputType, offsets, sizes, strides);
+      getReducedShape(inputType.getShape()), inputType, offsets, sizes,
+      strides);
   return canonicalizeStridedLayout(cast<MemRefType>(rankReducedType));
 }
 
@@ -277,17 +285,18 @@ static Value rankReducingSubviewDroppingUnitDims(PatternRewriter &rewriter,
                                                  mlir::Location loc,
                                                  Value input) {
   MemRefType inputType = cast<MemRefType>(input.getType());
-  assert(inputType.hasStaticShape());
-  SmallVector<int64_t> subViewOffsets(inputType.getRank(), 0);
-  SmallVector<int64_t> subViewStrides(inputType.getRank(), 1);
-  ArrayRef<int64_t> subViewSizes = inputType.getShape();
-  MemRefType resultType =
-      dropUnitDims(inputType, subViewOffsets, subViewSizes, subViewStrides);
+  SmallVector<OpFoldResult> offsets(inputType.getRank(),
+                                    rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(rewriter, loc, input);
+  SmallVector<OpFoldResult> strides(inputType.getRank(),
+                                    rewriter.getIndexAttr(1));
+  MemRefType resultType = dropUnitDims(inputType, offsets, sizes, strides);
+
   if (canonicalizeStridedLayout(resultType) ==
       canonicalizeStridedLayout(inputType))
     return input;
-  return rewriter.create<memref::SubViewOp>(
-      loc, resultType, input, subViewOffsets, subViewSizes, subViewStrides);
+  return rewriter.create<memref::SubViewOp>(loc, resultType, input, offsets,
+                                            sizes, strides);
 }
 
 /// Returns the number of dims that aren't unit dims.
@@ -295,12 +304,18 @@ static int getReducedRank(ArrayRef<int64_t> shape) {
   return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; });
 }
 
-/// Returns a copy of `shape` without unit dims.
-static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) {
-  SmallVector<int64_t> reducedShape;
-  llvm::copy_if(shape, std::back_inserter(reducedShape),
-                [](int64_t dimSize) { return dimSize != 1; });
-  return reducedShape;
+/// Trims non-scalable one dimensions from `oldType` and returns the result
+/// type.
+static VectorType trimUnitDims(VectorType oldType) {
+  SmallVector<int64_t> newShape;
+  SmallVector<bool> newScalableDims;
+  for (auto [dimIdx, dimSize] : llvm::enumerate(oldType.getShape())) {
+    if (dimSize == 1 && !oldType.getScalableDims()[dimIdx])
+      continue;
+    newShape.push_back(dimSize);
+    newScalableDims.push_back(oldType.getScalableDims()[dimIdx]);
+  }
+  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
 }
 
 namespace {
@@ -320,9 +335,7 @@ class TransferReadDropUnitDimsPattern
     Value source = transferReadOp.getSource();
     MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
     // TODO: support tensor types.
-    if (!sourceType || !sourceType.hasStaticShape())
-      return failure();
-    if (sourceType.getNumElements() != vectorType.getNumElements())
+    if (!sourceType)
       return failure();
     // TODO: generalize this pattern, relax the requirements here.
     if (transferReadOp.hasOutOfBoundsDim())
@@ -335,23 +348,50 @@ class TransferReadDropUnitDimsPattern
       return failure();
     // Check if the reduced vector shape matches the reduced source shape.
     // Otherwise, this case is not supported yet.
-    int vectorReducedRank = getReducedRank(vectorType.getShape());
-    if (reducedRank != vectorReducedRank)
+    auto reducedVectorType = trimUnitDims(vectorType);
+    if (reducedRank != reducedVectorType.getRank())
       return failure();
     if (llvm::any_of(transferReadOp.getIndices(), [](Value v) {
           return getConstantIntValue(v) != static_cast<int64_t>(0);
         }))
       return failure();
+
+    auto maskOp = transferReadOp.getMask();
+    if (maskOp) {
+      auto createMaskOp = maskOp.getDefiningOp<vector::CreateMaskOp>();
+      if (!createMaskOp)
+        return failure();
+      auto maskType = maskOp.getType();
+      auto reducedMaskType = trimUnitDims(maskType);
+      if (reducedMaskType.getRank() == maskType.getRank())
+        return failure();
+      SmallVector<Value> maskOperands;
+      for (auto [dim, dimIsScalable, maskOperand] :
+           llvm::zip(maskType.getShape(), maskType.getScalableDims(),
+                     createMaskOp.getOperands())) {
+        if (dim == 1 && !dimIsScalable) {
+          // If the mask for the unit dim is not a constant of 1, do nothing.
+          auto constant = maskOperand.getDefiningOp<arith::ConstantIndexOp>();
+          if (!constant || (constant.value() != 1))
+            return failure();
+          continue;
+        }
+        maskOperands.push_back(maskOperand);
+      }
+      maskOp = rewriter.create<vector::CreateMaskOp>(loc, reducedMaskType,
+                                                     maskOperands);
+    }
+
     Value reducedShapeSource =
         rankReducingSubviewDroppingUnitDims(rewriter, loc, source);
     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     SmallVector<Value> zeros(reducedRank, c0);
     auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
-    auto reducedVectorType = VectorType::get(
-        getReducedShape(vectorType.getShape()), vectorType.getElementType());
-
+    SmallVector<bool> inBounds(reducedVectorType.getRank(), true);
     auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
-        loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
+        loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
+        transferReadOp.getPadding(), maskOp,
+        rewriter.getBoolArrayAttr(inBounds));
     auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
         loc, vectorType, newTransferReadOp);
     rewriter.replaceOp(transferReadOp, shapeCast);
diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
index 2852e301888cca8..688fcd114041812 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir
@@ -82,6 +82,92 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
 //       CHECK:   %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
 //       CHECK:   vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
 
+func.func @transfer_read_dynamic_rank_reducing(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
+    %c0 = arith.constant 0 : index
+    %pad = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0], %pad {in_bounds = [true, true]} :
+      memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+    return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @transfer_read_dynamic_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+//       CHECK:   vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<?xi8, {{.*}}>, vector<[16]xi8>
+
+func.func @masked_transfer_read_dynamic_rank_reducing(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+      %mask_dim0 : index) -> vector<[16]x1xi8> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %pad = arith.constant 0 : i8
+    %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x1xi1>
+    %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+      memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+    return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing
+//  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
+//  CHECK-SAME:     %[[MASK_DIM0:.+]]: index
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[PAD:.+]] = arith.constant 0 : i8
+//       CHECK:   %[[MASK:.+]] = vector.create_mask %[[MASK_DIM0]] : vector<[16]xi1>
+//       CHECK:   %[[DIM0:.+]] = memref.dim %[[ARG]], %[[C0]] : memref<?x1xi8, strided<[?, ?], offset: ?>>
+//       CHECK:   %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0] [%[[DIM0]], 1] [1, 1] : memref<?x1xi8, {{.*}}> to memref<?xi8, {{.*}}>
+//       CHECK:   vector.transfer_read %[[SUBVIEW]][{{.*}}], %[[PAD]], %[[MASK]] {in_bounds = [true]} : memref<?xi8, {{.*}}>, vector<[16]xi8>
+
+/// Only vector.create_mask is currently supported.
+func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_1(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+      %mask : vector<[16]x1xi1>) -> vector<[16]x1xi8> {
+    %c0 = arith.constant 0 : index
+    %pad = arith.constant 0 : i8
+    %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+      memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+    return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @unsupported_masked_transfer_read_dynamic_rank_reducing_1
+//  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
+//   CHECK-NOT: vector.create_mask
+//   CHECK-NOT: memref.subview
+//       CHECK: vector.transfer_read %[[ARG]]
+
+/// Unit dim mask must be constant of 1.
+func.func @unsupported_masked_transfer_read_dynamic_rank_reducing_2(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+      %mask_dim0 : index, %mask_dim1 : index) -> vector<[16]x1xi8> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %pad = arith.constant 0 : i8
+    %mask = vector.create_mask %mask_dim0, %mask_dim1 : vector<[16]x1xi1>
+    %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+      memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
+    return %v : vector<[16]x1xi8>
+}
+// CHECK-LABEL: func @unsupported_masked_transfer_read_dynamic_rank_reducing_2
+//  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
+//   CHECK-NOT: memref.subview
+//       CHECK: vector.transfer_read {{.*}} vector<[16]x1xi8>
+
+/// Unit dim must be non-scalable.
+func.func @masked_transfer_read_dynamic_rank_reducing_scalable_unit_dim(
+      %arg : memref<?x1xi8, strided<[?, ?], offset: ?>>,
+      %mask_dim0 : index) -> vector<[16]x[1]xi8> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %pad = arith.constant 0 : i8
+    %mask = vector.create_mask %mask_dim0, %c1 : vector<[16]x[1]xi1>
+    %v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]} :
+      memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x[1]xi8>
+    return %v : vector<[16]x[1]xi8>
+}
+// CHECK-LABEL: func @masked_transfer_read_dynamic_rank_reducing_scalable_unit_dim
+//  CHECK-SAME:     %[[ARG:.+]]: memref<?x1xi8
+//   CHECK-NOT: memref.subview
+//       CHECK: vector.transfer_read {{.*}} vector<[16]x[1]xi8>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%func_op: !transform.op<"func.func"> {transform.readonly}) {
     transform.apply_patterns to %func_op {

Member

@MacDue MacDue left a comment


This seems reasonable to me :). The docs you link seem to say this is allowed, as this is only removing statically known unit dims. The added shape_cast won't be something that can be lowered, but should (hopefully) fold away.

https://mlir.llvm.org/docs/Dialects/MemRef/#memrefsubview-memrefsubviewop

A subview operation may additionally reduce the rank of the resulting view by removing dimensions that are statically known to be of size 1.
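For example, a hand-worked sketch of the folding (assuming the usual vector canonicalization patterns run afterwards): the shape_cast the pattern inserts rebuilds the original rank-2 type, which then cancels against the transpose and extract that consume it:

```mlir
%0 = vector.shape_cast %read : vector<[4]xf32> to vector<[4]x1xf32>
%1 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
%2 = vector.extract %1[0] : vector<[4]xf32> from vector<1x[4]xf32>
// %2 is just %read again, so all three ops should fold away.
```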

Comment on lines 373 to 377
// If the mask for the unit dim is not a constant of 1, do nothing.
auto constant = maskOperand.getDefiningOp<arith::ConstantIndexOp>();
if (!constant || (constant.value() != 1))
return failure();
continue;
Member


Probably done elsewhere, but if any dim in the mask is 0 this whole read folds to a constant splat of the padding.
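A sketch of the fold being described (hypothetical IR, not part of this patch):

```mlir
// Any zero mask dim disables every lane, so the read returns only padding:
%zero = arith.constant 0 : index
%mask = vector.create_mask %zero, %c1 : vector<[16]x1xi1>
%v = vector.transfer_read %arg[%c0, %c0], %pad, %mask {in_bounds = [true, true]}
    : memref<?x1xi8, strided<[?, ?], offset: ?>>, vector<[16]x1xi8>
// ...which can fold to a splat of the padding value:
%v2 = vector.broadcast %pad : i8 to vector<[16]x1xi8>
```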

Collaborator Author


I haven't checked, but I figured that should already exist and isn't something to handle here.

Contributor


can we hoist this logic in a helper with a good name? It seems this is deep enough already.

Contributor


Yeah, that came up in another review recently. I think we have a canonicalization pattern for that already?

@c-rhodes
Collaborator Author

This seems reasonable to me :). The docs you link seem to say this is allowed, as this is only removing statically known unit dims. The added shape_cast won't be something that can be lowered, but should (hopefully) fold away.

can confirm the shape_cast folds away :)


auto newTransferReadOp = rewriter.create<vector::TransferReadOp>(
loc, reducedVectorType, reducedShapeSource, zeros, identityMap);
loc, reducedVectorType, reducedShapeSource, zeros, identityMap,
transferReadOp.getPadding(), maskOp,
Contributor


thanks for adding the previously omitted mask!

Contributor

@banach-space banach-space left a comment


Nice!

It would be good to add tests with non-trailing unit dim and when there's more than one unit dim (and perhaps a mix of scalable and non-scalable).

Contributor

@dcaballe dcaballe left a comment


Thanks!


* trimUnitDims -> trimNonScalableUnitDims
* add notifyMatchFailure on unsupported mask op
* llvm::zip -> llvm::zip_equal
* add helper to rewrite vector.create_mask to drop non-scalable unit dims.
* add getReducedShape that takes mixedSizes.
* add a more complex test.
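The renamed trimNonScalableUnitDims helper can be modelled in plain Python to show which dims survive (an illustrative model of the selection logic only, not the MLIR C++ API):

```python
def trim_non_scalable_unit_dims(shape, scalable_dims):
    """Drop size-1 dims unless they are scalable, mirroring the C++ helper."""
    new_shape, new_scalable = [], []
    for size, is_scalable in zip(shape, scalable_dims):
        if size == 1 and not is_scalable:
            continue  # non-scalable unit dim: dropped
        new_shape.append(size)
        new_scalable.append(is_scalable)
    return new_shape, new_scalable

# vector<[16]x1xi8> -> vector<[16]xi8>: the trailing unit dim goes away...
print(trim_non_scalable_unit_dims([16, 1], [True, False]))  # ([16], [True])
# ...but a scalable [1] dim is kept, since its runtime size is vscale, not 1.
print(trim_non_scalable_unit_dims([16, 1], [True, True]))   # ([16, 1], [True, True])
```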
@c-rhodes
Collaborator Author

It would be good to add tests with non-trailing unit dim and when there's more than one unit dim (and perhaps a mix of scalable and non-scalable).

I've added a more complex test.

@c-rhodes c-rhodes merged commit bf897d5 into llvm:main Nov 20, 2023
@c-rhodes c-rhodes deleted the mlir-drop-unit-dims-dynamic-scalable branch November 20, 2023 08:39
sr-tream pushed a commit to sr-tream/llvm-project that referenced this pull request Nov 20, 2023
…ially-static memrefs (llvm#72142)

zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
…ially-static memrefs (llvm#72142)

qedawkins added a commit to qedawkins/llvm-project that referenced this pull request Dec 1, 2023
This does the same as llvm#72142 for vector.transfer_write. Previously the
pattern would silently drop the mask.
qedawkins added a commit to qedawkins/llvm-project that referenced this pull request Dec 1, 2023
This does the same as llvm#72142 for vector.transfer_write. Previously the
pattern would silently drop the mask.
qedawkins added a commit that referenced this pull request Dec 1, 2023
This does the same as #72142 for vector.transfer_write. Previously the
pattern would silently drop the mask.
qedawkins added a commit to iree-org/llvm-project that referenced this pull request Dec 1, 2023
This does the same as llvm#72142 for vector.transfer_write. Previously the
pattern would silently drop the mask.
stellaraccident pushed a commit to iree-org/llvm-project that referenced this pull request Dec 7, 2023
This does the same as llvm#72142 for vector.transfer_write. Previously the
pattern would silently drop the mask.