@@ -96,6 +96,87 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
96
96
}
97
97
};
98
98
99
+ // / Rewrites a vector.gather of a strided MemRef as a gather of a non-strided
100
+ // / MemRef with updated indices that model the strided access.
101
+ // /
102
+ // / ```mlir
103
+ // / %subview = memref.subview %M (...)
104
+ // / : memref<100x3xf32> to memref<100xf32, strided<[3]>>
105
+ // / %gather = vector.gather %subview[%idxs] (...) : memref<100xf32, strided<[3]>>
106
+ // / ```
107
+ // / ==>
108
+ // / ```mlir
109
+ // / %collapse_shape = memref.collapse_shape %M (...)
110
+ // / : memref<100x3xf32> into memref<300xf32>
111
+ // / %new_idxs = arith.muli %idxs, %c3 : vector<4xindex>
112
+ // / %gather = vector.gather %collapse_shape[%new_idxs] (...)
113
+ // / : memref<300xf32> (...)
114
+ // / ```
115
+ // /
116
+ // / ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
117
+ // / but should be fairly straightforward to extend beyond that.
118
+ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
119
+ using OpRewritePattern::OpRewritePattern;
120
+
121
+ LogicalResult matchAndRewrite (vector::GatherOp op,
122
+ PatternRewriter &rewriter) const override {
123
+ Value base = op.getBase ();
124
+
125
+ // TODO: Strided accesses might be coming from other ops as well
126
+ auto subview = base.getDefiningOp <memref::SubViewOp>();
127
+ if (!subview)
128
+ return failure ();
129
+
130
+ auto sourceType = subview.getSource ().getType ();
131
+
132
+ // TODO: Allow ranks > 2.
133
+ if (sourceType.getRank () != 2 )
134
+ return failure ();
135
+
136
+ // Get strides
137
+ auto layout = subview.getResult ().getType ().getLayout ();
138
+ auto stridedLayoutAttr = llvm::dyn_cast<StridedLayoutAttr>(layout);
139
+ if (!stridedLayoutAttr)
140
+ return failure ();
141
+
142
+ // TODO: Allow the access to be strided in multiple dimensions.
143
+ if (stridedLayoutAttr.getStrides ().size () != 1 )
144
+ return failure ();
145
+
146
+ int64_t srcTrailingDim = sourceType.getShape ().back ();
147
+
148
+ // Assume that the stride matches the trailing dimension of the source
149
+ // memref.
150
+ // TODO: Relax this assumption.
151
+ if (stridedLayoutAttr.getStrides ()[0 ] != srcTrailingDim)
152
+ return failure ();
153
+
154
+ // 1. Collapse the input memref so that it's "flat".
155
+ SmallVector<ReassociationIndices> reassoc = {{0 , 1 }};
156
+ Value collapsed = rewriter.create <memref::CollapseShapeOp>(
157
+ op.getLoc (), subview.getSource (), reassoc);
158
+
159
+ // 2. Generate new gather indices that will model the
160
+ // strided access.
161
+ IntegerAttr stride = rewriter.getIndexAttr (srcTrailingDim);
162
+ VectorType vType = op.getIndexVec ().getType ();
163
+ Value mulCst = rewriter.create <arith::ConstantOp>(
164
+ op.getLoc (), vType, DenseElementsAttr::get (vType, stride));
165
+
166
+ Value newIdxs =
167
+ rewriter.create <arith::MulIOp>(op.getLoc (), op.getIndexVec (), mulCst);
168
+
169
+ // 3. Create an updated gather op with the collapsed input memref and the
170
+ // updated indices.
171
+ Value newGather = rewriter.create <vector::GatherOp>(
172
+ op.getLoc (), op.getResult ().getType (), collapsed, op.getIndices (),
173
+ newIdxs, op.getMask (), op.getPassThru ());
174
+ rewriter.replaceOp (op, newGather);
175
+
176
+ return success ();
177
+ }
178
+ };
179
+
99
180
// / Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or
100
181
// / `tensor.extract`s. To avoid out-of-bounds memory accesses, these
101
182
// / loads/extracts are made conditional using `scf.if` ops.
@@ -115,6 +196,16 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
115
196
116
197
Value condMask = op.getMask ();
117
198
Value base = op.getBase ();
199
+
200
+ // vector.load requires the most minor memref dim to have unit stride
201
+ if (auto memType = dyn_cast<MemRefType>(base.getType ())) {
202
+ if (auto stridesAttr =
203
+ dyn_cast_if_present<StridedLayoutAttr>(memType.getLayout ())) {
204
+ if (stridesAttr.getStrides ().back () != 1 )
205
+ return failure ();
206
+ }
207
+ }
208
+
118
209
Value indexVec = rewriter.createOrFold <arith::IndexCastOp>(
119
210
loc, op.getIndexVectorType ().clone (rewriter.getIndexType ()),
120
211
op.getIndexVec ());
@@ -168,6 +259,6 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
168
259
169
260
void mlir::vector::populateVectorGatherLoweringPatterns (
170
261
RewritePatternSet &patterns, PatternBenefit benefit) {
171
- patterns.add <FlattenGather, Gather1DToConditionalLoads>(patterns. getContext () ,
172
- benefit);
262
+ patterns.add <FlattenGather, RemoveStrideFromGatherSource ,
263
+ Gather1DToConditionalLoads>(patterns. getContext (), benefit);
173
264
}
0 commit comments