Skip to content

[mlir][Vector] Add a rewrite pattern for gather over a strided memref #72991

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Nov 30, 2023

Conversation

banach-space
Copy link
Contributor

This patch adds a rewrite pattern for vector.gather over a strided memref like the following:

%subview = memref.subview %arg0[0, 0] [100, 1] [1, 1] :
    memref<100x3xf32> to memref<100xf32, strided<[3]>>
%gather = vector.gather %subview[%c0] [%idxs], %cst_0, %cst :
    memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32>
    into vector<4xf32>
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
    memref<100x3xf32> into memref<300xf32>
%1 = arith.muli %arg3, %cst : vector<4xindex>
%gather = vector.gather %collapse_shape[%c0] [%1], %cst_1, %cst_0 :
    memref<300xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
    into vector<4xf32>

Fixes iree-org/iree#15364.

This patch adds a rewrite pattern for `vector.gather` over a strided
memref like the following:

```mlir
%subview = memref.subview %arg0[0, 0] [100, 1] [1, 1] :
    memref<100x3xf32> to memref<100xf32, strided<[3]>>
%gather = vector.gather %subview[%c0] [%idxs], %cst_0, %cst :
    memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32>
    into vector<4xf32>
```

```mlir
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
    memref<100x3xf32> into memref<300xf32>
%1 = arith.muli %arg3, %cst : vector<4xindex>
%gather = vector.gather %collapse_shape[%c0] [%1], %cst_1, %cst_0 :
    memref<300xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
    into vector<4xf32>
```

Fixes iree-org/iree#15364.
@llvmbot
Copy link
Member

llvmbot commented Nov 22, 2023

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

This patch adds a rewrite pattern for vector.gather over a strided memref like the following:

%subview = memref.subview %arg0[0, 0] [100, 1] [1, 1] :
    memref&lt;100x3xf32&gt; to memref&lt;100xf32, strided&lt;[3]&gt;&gt;
%gather = vector.gather %subview[%c0] [%idxs], %cst_0, %cst :
    memref&lt;100xf32, strided&lt;[3]&gt;&gt;, vector&lt;4xindex&gt;, vector&lt;4xi1&gt;, vector&lt;4xf32&gt;
    into vector&lt;4xf32&gt;
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
    memref&lt;100x3xf32&gt; into memref&lt;300xf32&gt;
%1 = arith.muli %arg3, %cst : vector&lt;4xindex&gt;
%gather = vector.gather %collapse_shape[%c0] [%1], %cst_1, %cst_0 :
    memref&lt;300xf32&gt;, vector&lt;4xindex&gt;, vector&lt;4xi1&gt;, vector&lt;4xf32&gt;
    into vector&lt;4xf32&gt;

Fixes iree-org/iree#15364.


Full diff: https://github.com/llvm/llvm-project/pull/72991.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp (+78-2)
  • (modified) mlir/test/Dialect/Vector/vector-gather-lowering.mlir (+54)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 152aefa65effc3d..54b350d7ac3524c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -96,6 +96,82 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
   }
 };
 
+/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
+/// MemRef with updated indices that model the strided access.
+///
+/// ```mlir
+/// %subview = memref.subview %M (...) to memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview (...) : memref<100xf32, strided<[3]>>
+/// ```
+/// ==>
+/// ```mlir
+/// %collapse_shape = memref.collapse_shape %M (...) into memref<300xf32>
+/// %1 = arith.muli %idxs, %c3 : vector<4xindex>
+/// %gather = vector.gather %collapse_shape (...) : memref<300xf32> (...)
+/// ```
+///
+/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
+/// but should be fairly straightforward to extend beyond that.
+struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    Value base = op.getBase();
+    if (!base.getDefiningOp())
+      return failure();
+
+    // TODO: Strided accesses might be coming from other ops as well
+    auto subview = dyn_cast<memref::SubViewOp>(base.getDefiningOp());
+    if (!subview)
+      return failure();
+
+    // TODO: Allows ranks > 2.
+    if (subview.getSource().getType().getRank() != 2)
+      return failure();
+
+    // Get strides
+    auto layout = subview.getResult().getType().getLayout();
+    auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
+
+    // TODO: Allow the access to be strided in multiple dimensions.
+    if (stridedLayoutAttr.getStrides().size() != 1)
+      return failure();
+
+    int64_t srcTrailingDim = subview.getSource().getType().getShape().back();
+
+    // Assume that the stride matches the trailing dimension of the source
+    // memref.
+    // TODO: Relax this assumption.
+    if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
+      return failure();
+
+    // 1. Collapse the input memref so that it's "flat".
+    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
+    Value collapsed = rewriter.create<memref::CollapseShapeOp>(
+        op.getLoc(), subview.getSource(), reassoc);
+
+    // 2. Generate new gather indices that will model the
+    // strided access.
+    auto stride = rewriter.getIndexAttr(srcTrailingDim);
+    auto vType = op.getIndexVec().getType();
+    Value mulCst = rewriter.create<arith::ConstantOp>(
+        op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
+
+    Value newIdxs =
+        rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);
+
+    // 3. Create an updated gather op with the collapsed input memref and the
+    // updated indices.
+    Value newGather = rewriter.create<vector::GatherOp>(
+        op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
+        newIdxs, op.getMask(), op.getPassThru());
+    rewriter.replaceOp(op, newGather);
+
+    return success();
+  }
+};
+
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
@@ -168,6 +244,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
 void mlir::vector::populateVectorGatherLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
-                                                          benefit);
+  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
+               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 026bec8cd65d3f5..3de7f44e4fb3e27 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -151,3 +151,57 @@ func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>,
   %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
   return %0 : vector<2xf32>
 }
+
+// Check that vector.gather of a strided memref is replaced with a
+// vector.gather with indices encoding the original strides. Note that with the
+// other patterns
+#map = affine_map<()[s0] -> (s0 * 4096)>
+#map1 = affine_map<()[s0] -> (s0 * -4096 + 518400, 4096)>
+func.func @strided_gather(%M_in : memref<100x3xf32>, %M_out: memref<518400xf32>, %idxs : vector<4xindex>, %x : index, %y : index) {
+  %c0 = arith.constant 0 : index
+  %x_1 = affine.apply #map()[%x]
+  // Strided MemRef
+  %subview = memref.subview %M_in[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+  %cst_0 = arith.constant dense<true> : vector<4xi1>
+  %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+  // Gather of a strided MemRef
+  %7 = vector.gather %subview[%c0] [%idxs], %cst_0, %cst : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  %subview_1 = memref.subview %M_out[%x_1] [%y] [1] : memref<518400xf32> to memref<?xf32, strided<[1], offset: ?>>
+  vector.store %7, %subview_1[%c0] : memref<?xf32, strided<[1], offset: ?>>, vector<4xf32>
+  return
+}
+// CHECK-LABEL:   func.func @strided_gather(
+// CHECK-SAME:                         %[[M_in:.*]]: memref<100x3xf32>,
+// CHECK-SAME:                         %[[M_out:.*]]: memref<518400xf32>,
+// CHECK-SAME:                         %[[IDXS:.*]]: vector<4xindex>,
+// CHECK-SAME:                         %[[VAL_4:.*]]: index,
+// CHECK-SAME:                         %[[VAL_5:.*]]: index) {
+// CHECK:           %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
+// CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
+
+// CHECK:           %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_in]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
+// CHECK:           %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
+
+// CHECK:           %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
+// CHECK:           %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
+// CHECK:           scf.if %[[MASK_0]] -> (vector<4xf32>)
+// CHECK:             %[[M_0:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
+// CHECK:             %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
+
+// CHECK:           %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
+// CHECK:           %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
+// CHECK:           scf.if %[[MASK_1]] -> (vector<4xf32>)
+// CHECK:             %[[M_1:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
+// CHECK:             %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
+
+// CHECK:           %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
+// CHECK:           %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
+// CHECK:           scf.if %[[MASK_2]] -> (vector<4xf32>)
+// CHECK:             %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
+// CHECK:             %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
+
+// CHECK:           %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
+// CHECK:           %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
+// CHECK:           scf.if %[[MASK_3]] -> (vector<4xf32>)
+// CHECK:             %[[M_3:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
+// CHECK:             %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>

Copy link
Member

@NicolaLancellotti NicolaLancellotti left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @banach-space , LGTM!

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems alright to me, though maybe let someone else have a quick check. Just a few little nits:

@hanhanW hanhanW self-requested a review November 27, 2023 19:47
Copy link

github-actions bot commented Nov 28, 2023

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 21646789497346a1a8dabb4b369e12db482b4daa 258cd7d1aea9a179ad1810e4cc1a5f06bd7ff729 -- mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 90128126d0..4465874f68 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -102,7 +102,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
 /// ```mlir
 ///   %subview = memref.subview %M (...)
 ///     : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32,
+///   strided<[3]>>
 /// ```
 /// ==>
 /// ```mlir

@banach-space
Copy link
Contributor Author

I've just sent an update and resolved threads that should be address by that update. Please re-open if I missed something 🙏🏻 .

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just one nit. Thanks!

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Few things I think can be removed, but otherwise LGTM, cheers

@banach-space banach-space merged commit a383817 into llvm:main Nov 30, 2023
@banach-space banach-space deleted the andrzej/vector_gather_strided branch November 30, 2023 21:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

iree-compile fails because of the lowering to builtin.unrealized_conversion_cast operations
7 participants