[mlir][Vector] Add a rewrite pattern for gather over a strided memref #72991
Conversation
This patch adds a rewrite pattern for `vector.gather` over a strided memref like the following:

```mlir
%subview = memref.subview %arg0[0, 0] [100, 1] [1, 1] :
  memref<100x3xf32> to memref<100xf32, strided<[3]>>
%gather = vector.gather %subview[%c0] [%idxs], %cst_0, %cst :
  memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32>
  into vector<4xf32>
```

The gather is rewritten as a gather over the collapsed (non-strided) memref, with the indices scaled by the original stride:

```mlir
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] :
  memref<100x3xf32> into memref<300xf32>
%1 = arith.muli %arg3, %cst : vector<4xindex>
%gather = vector.gather %collapse_shape[%c0] [%1], %cst_1, %cst_0 :
  memref<300xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32>
  into vector<4xf32>
```

Fixes iree-org/iree#15364.
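For context (not part of the patch), here is a minimal sketch of how a downstream pass could pick up the new pattern through the existing `populateVectorGatherLoweringPatterns` entry point. The wrapper function and the header path are assumptions for illustration; only the populate call itself comes from this PR:

```cpp
// Minimal sketch: run the vector.gather lowering patterns (including the new
// RemoveStrideFromGatherSource) on a function-like op with the greedy driver.
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper; `root` is assumed to be e.g. a func.func op.
static LogicalResult lowerGathers(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Registers FlattenGather, RemoveStrideFromGatherSource and
  // Gather1DToConditionalLoads (see the diff below).
  vector::populateVectorGatherLoweringPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```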
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes: This patch adds a rewrite pattern for `vector.gather` over a strided memref, rewriting it as a gather over the collapsed memref with indices scaled by the original stride (see the examples in the description above). Fixes iree-org/iree#15364.

Full diff: https://github.com/llvm/llvm-project/pull/72991.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 152aefa65effc3d..54b350d7ac3524c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -96,6 +96,82 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
}
};
+/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
+/// MemRef with updated indices that model the strided access.
+///
+/// ```mlir
+/// %subview = memref.subview %M (...) to memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview (...) : memref<100xf32, strided<[3]>>
+/// ```
+/// ==>
+/// ```mlir
+/// %collapse_shape = memref.collapse_shape %M (...) into memref<300xf32>
+/// %1 = arith.muli %idxs, %c3 : vector<4xindex>
+/// %gather = vector.gather %collapse_shape (...) : memref<300xf32> (...)
+/// ```
+///
+/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
+/// but should be fairly straightforward to extend beyond that.
+struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::GatherOp op,
+ PatternRewriter &rewriter) const override {
+ Value base = op.getBase();
+ if (!base.getDefiningOp())
+ return failure();
+
+ // TODO: Strided accesses might be coming from other ops as well
+ auto subview = dyn_cast<memref::SubViewOp>(base.getDefiningOp());
+ if (!subview)
+ return failure();
+
+ // TODO: Allow ranks > 2.
+ if (subview.getSource().getType().getRank() != 2)
+ return failure();
+
+ // Get strides
+ auto layout = subview.getResult().getType().getLayout();
+ auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
+ // Bail out if the subview result does not have a strided layout.
+ if (!stridedLayoutAttr)
+ return failure();
+
+ // TODO: Allow the access to be strided in multiple dimensions.
+ if (stridedLayoutAttr.getStrides().size() != 1)
+ return failure();
+
+ int64_t srcTrailingDim = subview.getSource().getType().getShape().back();
+
+ // Assume that the stride matches the trailing dimension of the source
+ // memref.
+ // TODO: Relax this assumption.
+ if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
+ return failure();
+
+ // 1. Collapse the input memref so that it's "flat".
+ SmallVector<ReassociationIndices> reassoc = {{0, 1}};
+ Value collapsed = rewriter.create<memref::CollapseShapeOp>(
+ op.getLoc(), subview.getSource(), reassoc);
+
+ // 2. Generate new gather indices that will model the
+ // strided access.
+ auto stride = rewriter.getIndexAttr(srcTrailingDim);
+ auto vType = op.getIndexVec().getType();
+ Value mulCst = rewriter.create<arith::ConstantOp>(
+ op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
+
+ Value newIdxs =
+ rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);
+
+ // 3. Create an updated gather op with the collapsed input memref and the
+ // updated indices.
+ Value newGather = rewriter.create<vector::GatherOp>(
+ op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
+ newIdxs, op.getMask(), op.getPassThru());
+ rewriter.replaceOp(op, newGather);
+
+ return success();
+ }
+};
+
/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
@@ -168,6 +244,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
void mlir::vector::populateVectorGatherLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
- benefit);
+ patterns.add<FlattenGather, RemoveStrideFromGatherSource,
+ Gather1DToConditionalLoads>(patterns.getContext(), benefit);
}
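To make the index rewrite above concrete: the pattern multiplies every gather index by the stride of the collapsed dimension (the trailing dimension of the source memref), so the flattened memref is indexed at exactly the elements the strided view would have touched. A standalone sketch of that arithmetic in plain C++, with illustrative names not taken from the patch:

```cpp
#include <cstdint>
#include <vector>

// For a view with the given stride over a flattened buffer, the element at
// strided index i lives at flat index i * stride. This mirrors the
// arith.muli the pattern emits on the gather's index vector.
std::vector<int64_t> flattenStridedIndices(const std::vector<int64_t> &idxs,
                                           int64_t stride) {
  std::vector<int64_t> flat;
  flat.reserve(idxs.size());
  for (int64_t i : idxs)
    flat.push_back(i * stride); // stride 3: {0, 1, 2, 3} -> {0, 3, 6, 9}
  return flat;
}
```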
diff --git a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
index 026bec8cd65d3f5..3de7f44e4fb3e27 100644
--- a/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-gather-lowering.mlir
@@ -151,3 +151,57 @@ func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>,
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
return %0 : vector<2xf32>
}
+
+// Check that vector.gather of a strided memref is replaced with a
+// vector.gather with indices encoding the original strides. Note that, with
+// the other patterns in this lowering also applied, the rewritten gather is
+// then scalarized into conditional vector.loads (hence the scf.if checks
+// below).
+#map = affine_map<()[s0] -> (s0 * 4096)>
+#map1 = affine_map<()[s0] -> (s0 * -4096 + 518400, 4096)>
+func.func @strided_gather(%M_in : memref<100x3xf32>, %M_out: memref<518400xf32>, %idxs : vector<4xindex>, %x : index, %y : index) {
+ %c0 = arith.constant 0 : index
+ %x_1 = affine.apply #map()[%x]
+ // Strided MemRef
+ %subview = memref.subview %M_in[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+ %cst_0 = arith.constant dense<true> : vector<4xi1>
+ %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+ // Gather of a strided MemRef
+ %7 = vector.gather %subview[%c0] [%idxs], %cst_0, %cst : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+ %subview_1 = memref.subview %M_out[%x_1] [%y] [1] : memref<518400xf32> to memref<?xf32, strided<[1], offset: ?>>
+ vector.store %7, %subview_1[%c0] : memref<?xf32, strided<[1], offset: ?>>, vector<4xf32>
+ return
+}
+// CHECK-LABEL: func.func @strided_gather(
+// CHECK-SAME: %[[M_in:.*]]: memref<100x3xf32>,
+// CHECK-SAME: %[[M_out:.*]]: memref<518400xf32>,
+// CHECK-SAME: %[[IDXS:.*]]: vector<4xindex>,
+// CHECK-SAME: %[[VAL_4:.*]]: index,
+// CHECK-SAME: %[[VAL_5:.*]]: index) {
+// CHECK: %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
+
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_in]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
+// CHECK: %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
+
+// CHECK: %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
+// CHECK: %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_0]] -> (vector<4xf32>)
+// CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
+
+// CHECK: %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
+// CHECK: %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_1]] -> (vector<4xf32>)
+// CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
+
+// CHECK: %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
+// CHECK: %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_2]] -> (vector<4xf32>)
+// CHECK: %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
+
+// CHECK: %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
+// CHECK: %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>)
+// CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
Thank you @banach-space, LGTM!
This seems alright to me, though maybe let someone else have a quick check. Just a few little nits:
Refine based on PR feedback
You can test this locally with the following command:

git-clang-format --diff 21646789497346a1a8dabb4b369e12db482b4daa 258cd7d1aea9a179ad1810e4cc1a5f06bd7ff729 -- mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp

View the diff from clang-format here.

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 90128126d0..4465874f68 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -102,7 +102,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
/// ```mlir
/// %subview = memref.subview %M (...)
/// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
-/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32,
+/// strided<[3]>>
/// ```
/// ==>
/// ```mlir
Fix formatting
I've just sent an update and resolved the threads that should be addressed by that update. Please re-open if I missed something 🙏🏻.
Restrict Gather1DToConditionalLoads
LGTM, just one nit. Thanks!
Update comments
Thanks!
Few things I think can be removed, but otherwise LGTM, cheers
Remove unnecessary code