Commit a383817

[mlir][Vector] Add a rewrite pattern for gather over a strided memref (#72991)
This patch adds a rewrite pattern for `vector.gather` over a strided memref like the following:

```mlir
%subview = memref.subview %arg0[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
%gather = vector.gather %subview[%c0] [%idxs], %cst_0, %cst : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
```

After the pattern added in this patch:

```mlir
%collapse_shape = memref.collapse_shape %arg0 [[0, 1]] : memref<100x3xf32> into memref<300xf32>
%1 = arith.muli %arg3, %cst : vector<4xindex>
%gather = vector.gather %collapse_shape[%c0] [%1], %cst_1, %cst_0 : memref<300xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
```

The strided source is collapsed into a flat memref and the gather indices are scaled by the original stride (here 3), since element `i` of the strided view lives at linear offset `i * 3` in the collapsed buffer.

Fixes iree-org/iree#15364.
1 parent a51196e commit a383817

2 files changed: +148, -2 lines

mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp

Lines changed: 93 additions & 2 deletions
```diff
@@ -96,6 +96,87 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
   }
 };
 
+/// Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
+/// MemRef with updated indices that model the strided access.
+///
+/// ```mlir
+///   %subview = memref.subview %M (...)
+///     : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+///   %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
+/// ```
+/// ==>
+/// ```mlir
+///   %collapse_shape = memref.collapse_shape %M (...)
+///     : memref<100x3xf32> into memref<300xf32>
+///   %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
+///   %gather = vector.gather %collapse_shape[%new_idxs] (...)
+///     : memref<300xf32> (...)
+/// ```
+///
+/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
+/// but should be fairly straightforward to extend beyond that.
+struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::GatherOp op,
+                                PatternRewriter &rewriter) const override {
+    Value base = op.getBase();
+
+    // TODO: Strided accesses might be coming from other ops as well
+    auto subview = base.getDefiningOp<memref::SubViewOp>();
+    if (!subview)
+      return failure();
+
+    auto sourceType = subview.getSource().getType();
+
+    // TODO: Allow ranks > 2.
+    if (sourceType.getRank() != 2)
+      return failure();
+
+    // Get strides
+    auto layout = subview.getResult().getType().getLayout();
+    auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
+    if (!stridedLayoutAttr)
+      return failure();
+
+    // TODO: Allow the access to be strided in multiple dimensions.
+    if (stridedLayoutAttr.getStrides().size() != 1)
+      return failure();
+
+    int64_t srcTrailingDim = sourceType.getShape().back();
+
+    // Assume that the stride matches the trailing dimension of the source
+    // memref.
+    // TODO: Relax this assumption.
+    if (stridedLayoutAttr.getStrides()[0] != srcTrailingDim)
+      return failure();
+
+    // 1. Collapse the input memref so that it's "flat".
+    SmallVector<ReassociationIndices> reassoc = {{0, 1}};
+    Value collapsed = rewriter.create<memref::CollapseShapeOp>(
+        op.getLoc(), subview.getSource(), reassoc);
+
+    // 2. Generate new gather indices that will model the
+    // strided access.
+    IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim);
+    VectorType vType = op.getIndexVec().getType();
+    Value mulCst = rewriter.create<arith::ConstantOp>(
+        op.getLoc(), vType, DenseElementsAttr::get(vType, stride));
+
+    Value newIdxs =
+        rewriter.create<arith::MulIOp>(op.getLoc(), op.getIndexVec(), mulCst);
+
+    // 3. Create an updated gather op with the collapsed input memref and the
+    // updated indices.
+    Value newGather = rewriter.create<vector::GatherOp>(
+        op.getLoc(), op.getResult().getType(), collapsed, op.getIndices(),
+        newIdxs, op.getMask(), op.getPassThru());
+    rewriter.replaceOp(op, newGather);
+
+    return success();
+  }
+};
+
 /// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
 /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
 /// loads/extracts are made conditional using `scf.if` ops.
@@ -115,6 +196,16 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
     Value condMask = op.getMask();
     Value base = op.getBase();
+
+    // vector.load requires the most minor memref dim to have unit stride
+    if (auto memType = dyn_cast<MemRefType>(base.getType())) {
+      if (auto stridesAttr =
+              dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout())) {
+        if (stridesAttr.getStrides().back() != 1)
+          return failure();
+      }
+    }
+
     Value indexVec = rewriter.createOrFold<arith::IndexCastOp>(
         loc, op.getIndexVectorType().clone(rewriter.getIndexType()),
         op.getIndexVec());
@@ -168,6 +259,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
 
 void mlir::vector::populateVectorGatherLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<FlattenGather, Gather1DToConditionalLoads>(patterns.getContext(),
-                                                          benefit);
+  patterns.add<FlattenGather, RemoveStrideFromGatherSource,
+               Gather1DToConditionalLoads>(patterns.getContext(), benefit);
 }
```
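As a usage note, the new pattern is picked up automatically by anything that calls `populateVectorGatherLoweringPatterns`. Below is a minimal sketch, not part of this commit, of a downstream pass that applies these patterns greedily; the pass name `LowerGatherSketchPass` and its structure are illustrative assumptions, and the header paths may need adjusting to the local MLIR layout.

```cpp
// Sketch of a pass that exercises the gather lowering patterns, including
// the new RemoveStrideFromGatherSource (illustrative, not from this commit).
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
struct LowerGatherSketchPass
    : mlir::PassWrapper<LowerGatherSketchPass,
                        mlir::OperationPass<mlir::func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LowerGatherSketchPass)

  void getDependentDialects(mlir::DialectRegistry &registry) const override {
    // The lowering patterns may create arith, memref, scf and tensor ops.
    registry.insert<mlir::arith::ArithDialect, mlir::memref::MemRefDialect,
                    mlir::scf::SCFDialect, mlir::tensor::TensorDialect>();
  }

  void runOnOperation() override {
    mlir::RewritePatternSet patterns(&getContext());
    // Registers FlattenGather, RemoveStrideFromGatherSource and
    // Gather1DToConditionalLoads with the default benefit.
    mlir::vector::populateVectorGatherLoweringPatterns(patterns);
    if (mlir::failed(mlir::applyPatternsAndFoldGreedily(getOperation(),
                                                        std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```

In the MLIR tree itself these patterns are already exercised through the existing test setup for `vector-gather-lowering.mlir`; the sketch above only shows how a downstream user might pull them in.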

mlir/test/Dialect/Vector/vector-gather-lowering.mlir

Lines changed: 55 additions & 0 deletions
```diff
@@ -151,3 +151,58 @@ func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>,
   %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
   return %0 : vector<2xf32>
 }
+
+// Check that vector.gather of a strided memref is replaced with a
+// vector.gather with indices encoding the original strides. Note that multiple
+// patterns are run for this example, e.g.:
+//   1. "remove stride from gather source"
+//   2. "flatten gather"
+// However, the main goal is to test Pattern 1 above.
+#map = affine_map<()[s0] -> (s0 * 4096)>
+func.func @strided_gather(%base : memref<100x3xf32>,
+                          %idxs : vector<4xindex>,
+                          %x : index, %y : index) -> vector<4xf32> {
+  %c0 = arith.constant 0 : index
+  %x_1 = affine.apply #map()[%x]
+  // Strided MemRef
+  %subview = memref.subview %base[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+  %mask = arith.constant dense<true> : vector<4xi1>
+  %pass_thru = arith.constant dense<0.000000e+00> : vector<4xf32>
+  // Gather of a strided MemRef
+  %res = vector.gather %subview[%c0] [%idxs], %mask, %pass_thru : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  return %res : vector<4xf32>
+}
+// CHECK-LABEL:  func.func @strided_gather(
+// CHECK-SAME:     %[[base:.*]]: memref<100x3xf32>,
+// CHECK-SAME:     %[[IDXS:.*]]: vector<4xindex>,
+// CHECK-SAME:     %[[VAL_4:.*]]: index,
+// CHECK-SAME:     %[[VAL_5:.*]]: index) -> vector<4xf32> {
+// CHECK:          %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
+// CHECK:          %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
+
+// CHECK:          %[[COLLAPSED:.*]] = memref.collapse_shape %[[base]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
+// CHECK:          %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
+
+// CHECK:          %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
+// CHECK:          %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
+// CHECK:          scf.if %[[MASK_0]] -> (vector<4xf32>)
+// CHECK:            %[[M_0:.*]] = vector.load %[[COLLAPSED]][%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
+// CHECK:            %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
+
+// CHECK:          %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
+// CHECK:          %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
+// CHECK:          scf.if %[[MASK_1]] -> (vector<4xf32>)
+// CHECK:            %[[M_1:.*]] = vector.load %[[COLLAPSED]][%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
+// CHECK:            %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
+
+// CHECK:          %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
+// CHECK:          %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
+// CHECK:          scf.if %[[MASK_2]] -> (vector<4xf32>)
+// CHECK:            %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
+// CHECK:            %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
+
+// CHECK:          %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
+// CHECK:          %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
+// CHECK:          scf.if %[[MASK_3]] -> (vector<4xf32>)
+// CHECK:            %[[M_3:.*]] = vector.load %[[COLLAPSED]][%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
+// CHECK:            %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
```
