@@ -109,17 +109,110 @@ struct LinearizeVectorizable final
109
109
}
110
110
};
111
111
112
- // / This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
113
- // / on a linearized vector.
114
- // / Following,
112
+ template <typename TOp>
113
+ static bool stridesAllOne (TOp op) {
114
+ static_assert (
115
+ std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
116
+ std::is_same_v<TOp, vector::InsertStridedSliceOp>,
117
+ " expected vector.extract_strided_slice or vector.insert_strided_slice" );
118
+ ArrayAttr strides = op.getStrides ();
119
+ return llvm::all_of (
120
+ strides, [](auto stride) { return isConstantIntValue (stride, 1 ); });
121
+ }
122
+
123
+ // / Convert an array of attributes into a vector of integers, if possible.
124
+ static FailureOr<SmallVector<int64_t >> intsFromArrayAttr (ArrayAttr attrs) {
125
+ if (!attrs)
126
+ return failure ();
127
+ SmallVector<int64_t > ints;
128
+ ints.reserve (attrs.size ());
129
+ for (auto attr : attrs) {
130
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
131
+ ints.push_back (intAttr.getInt ());
132
+ } else {
133
+ return failure ();
134
+ }
135
+ }
136
+ return ints;
137
+ }
138
+
139
+ // / Consider inserting a vector of shape `small` into a vector of shape `large`,
140
+ // / at position `offsets`: this function enumeratates all the indices in `large`
141
+ // / that are written to. The enumeration is with row-major ordering.
142
+ // /
143
+ // / Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
144
+ // / positions written to are (1,3) and (1,4), which have linearized indices 8
145
+ // / and 9. So [8,9] is returned.
146
+ // /
147
+ // / The length of the returned vector is equal to the number of elements in
148
+ // / the shape `small` (i.e. the product of dimensions of `small`).
149
+ SmallVector<int64_t > static getStridedSliceInsertionIndices (
150
+ ArrayRef<int64_t > small, ArrayRef<int64_t > large,
151
+ ArrayRef<int64_t > offsets) {
152
+
153
+ // Example of alignment between, `large`, `small` and `offsets`:
154
+ // large = 4, 5, 6, 7, 8
155
+ // small = 1, 6, 7, 8
156
+ // offsets = 2, 3, 0
157
+ //
158
+ // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
159
+ assert ((large.size () >= small.size ()) &&
160
+ " rank of 'large' cannot be lower than rank of 'small'" );
161
+ assert ((large.size () >= offsets.size ()) &&
162
+ " rank of 'large' cannot be lower than the number of offsets" );
163
+ unsigned delta = large.size () - small.size ();
164
+ unsigned nOffsets = offsets.size ();
165
+ auto getSmall = [&](int64_t i) -> int64_t {
166
+ return i >= delta ? small[i - delta] : 1 ;
167
+ };
168
+ auto getOffset = [&](int64_t i) -> int64_t {
169
+ return i < nOffsets ? offsets[i] : 0 ;
170
+ };
171
+
172
+ // Using 2 vectors of indices, at each iteration populate the updated set of
173
+ // indices based on the old set of indices, and the size of the small vector
174
+ // in the current iteration.
175
+ SmallVector<int64_t > indices{0 };
176
+ int64_t stride = 1 ;
177
+ for (int i = large.size () - 1 ; i >= 0 ; --i) {
178
+ int64_t currentSize = indices.size ();
179
+ int64_t smallSize = getSmall (i);
180
+ int64_t nextSize = currentSize * smallSize;
181
+ SmallVector<int64_t > nextIndices (nextSize);
182
+ int64_t *base = nextIndices.begin ();
183
+ int64_t offset = getOffset (i) * stride;
184
+ for (int j = 0 ; j < smallSize; ++j) {
185
+ for (int k = 0 ; k < currentSize; ++k) {
186
+ base[k] = indices[k] + offset;
187
+ }
188
+ offset += stride;
189
+ base += currentSize;
190
+ }
191
+ stride *= large[i];
192
+ indices = std::move (nextIndices);
193
+ }
194
+ return indices;
195
+ }
196
+
197
+ // / This pattern converts a vector.extract_strided_slice operation into a
198
+ // / vector.shuffle operation that has a rank-1 (linearized) operand and result.
199
+ // /
200
+ // / For example, the following:
201
+ // /
202
+ // / ```
115
203
// / vector.extract_strided_slice %source
116
204
// / { offsets = [..], strides = [..], sizes = [..] }
205
+ // / ```
206
+ // /
117
207
// / is converted to :
208
+ // / ```
118
209
// / %source_1d = vector.shape_cast %source
119
- // / %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
120
- // / %out_nd = vector.shape_cast %out_1d
121
- // / `shuffle_indices_1d` is computed using the offsets and sizes of the
122
- // / extraction.
210
+ // / %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
211
+ // / %out_nd = vector.shape_cast %out_1d
212
+ // / ```
213
+ // /
214
+ // / `shuffle_indices_1d` is computed using the offsets and sizes of the original
215
+ // / vector.extract_strided_slice operation.
123
216
struct LinearizeVectorExtractStridedSlice final
124
217
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
125
218
using OpConversionPattern::OpConversionPattern;
@@ -129,88 +222,116 @@ struct LinearizeVectorExtractStridedSlice final
129
222
: OpConversionPattern(typeConverter, context, benefit) {}
130
223
131
224
LogicalResult
132
- matchAndRewrite (vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
225
+ matchAndRewrite (vector::ExtractStridedSliceOp extractStridedSliceOp,
226
+ OpAdaptor adaptor,
133
227
ConversionPatternRewriter &rewriter) const override {
134
- VectorType dstType =
135
- getTypeConverter ()->convertType <VectorType>(extractOp.getType ());
136
- assert (dstType && " vector type destination expected." );
137
- if (extractOp.getVector ().getType ().isScalable () || dstType.isScalable ())
138
- return rewriter.notifyMatchFailure (extractOp,
139
- " scalable vectors are not supported." );
140
228
141
- ArrayAttr offsets = extractOp.getOffsets ();
142
- ArrayAttr sizes = extractOp.getSizes ();
143
- ArrayAttr strides = extractOp.getStrides ();
144
- if (!isConstantIntValue (strides[0 ], 1 ))
229
+ VectorType flatOutputType = getTypeConverter ()->convertType <VectorType>(
230
+ extractStridedSliceOp.getType ());
231
+ assert (flatOutputType && " vector type expected" );
232
+
233
+ // Expect a legalization failure if the strides are not all 1 (if ever the
234
+ // verifier for extract_strided_slice allows non-1 strides).
235
+ if (!stridesAllOne (extractStridedSliceOp)) {
145
236
return rewriter.notifyMatchFailure (
146
- extractOp, " Strided slice with stride != 1 is not supported." );
147
- Value srcVector = adaptor.getVector ();
148
- // If kD offsets are specified for nD source vector (n > k), the granularity
149
- // of the extraction is greater than 1. In this case last (n-k) dimensions
150
- // form the extraction granularity.
151
- // Example :
152
- // vector.extract_strided_slice %src {
153
- // offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
154
- // vector<4x8x8xf32> to vector<2x2x8xf32>
155
- // Here, extraction granularity is 8.
156
- int64_t extractGranularitySize = 1 ;
157
- int64_t nD = extractOp.getSourceVectorType ().getRank ();
158
- int64_t kD = (int64_t )offsets.size ();
159
- int64_t k = kD ;
160
- while (k < nD) {
161
- extractGranularitySize *= extractOp.getSourceVectorType ().getShape ()[k];
162
- ++k;
237
+ extractStridedSliceOp,
238
+ " extract_strided_slice with strides != 1 not supported" );
163
239
}
164
- // Get total number of extracted slices.
165
- int64_t nExtractedSlices = 1 ;
166
- for (Attribute size : sizes) {
167
- nExtractedSlices *= cast<IntegerAttr>(size).getInt ();
240
+
241
+ FailureOr<SmallVector<int64_t >> offsets =
242
+ intsFromArrayAttr (extractStridedSliceOp.getOffsets ());
243
+ if (failed (offsets)) {
244
+ return rewriter.notifyMatchFailure (extractStridedSliceOp,
245
+ " failed to get integer offsets" );
168
246
}
169
- // Compute the strides of the source vector considering first k dimensions.
170
- llvm::SmallVector<int64_t , 4 > sourceStrides (kD , extractGranularitySize);
171
- for (int i = kD - 2 ; i >= 0 ; --i) {
172
- sourceStrides[i] = sourceStrides[i + 1 ] *
173
- extractOp.getSourceVectorType ().getShape ()[i + 1 ];
247
+
248
+ ArrayRef<int64_t > inputShape =
249
+ extractStridedSliceOp.getSourceVectorType ().getShape ();
250
+
251
+ ArrayRef<int64_t > outputShape = extractStridedSliceOp.getType ().getShape ();
252
+
253
+ SmallVector<int64_t > indices = getStridedSliceInsertionIndices (
254
+ outputShape, inputShape, offsets.value ());
255
+
256
+ Value srcVector = adaptor.getVector ();
257
+ rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
258
+ extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
259
+ return success ();
260
+ }
261
+ };
262
+
263
+ // / This pattern converts a vector.insert_strided_slice operation into a
264
+ // / vector.shuffle operation that has rank-1 (linearized) operands and result.
265
+ // /
266
+ // / For example, the following:
267
+ // / ```
268
+ // / %0 = vector.insert_strided_slice %to_store, %into
269
+ // / {offsets = [1, 0, 0, 0], strides = [1, 1]}
270
+ // / : vector<2x2xi8> into vector<2x1x3x2xi8>
271
+ // / ```
272
+ // /
273
+ // / is converted to
274
+ // / ```
275
+ // / %to_store_1d
276
+ // / = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
277
+ // / %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
278
+ // / %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
279
+ // / %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
280
+ // / ```
281
+ // /
282
+ // / where shuffle_indices_1d in this case is
283
+ // / [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
284
+ // / ^^^^^^^^^^^^^^
285
+ // / to_store_1d
286
+ // /
287
+ struct LinearizeVectorInsertStridedSlice final
288
+ : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
289
+ using OpConversionPattern::OpConversionPattern;
290
+ LinearizeVectorInsertStridedSlice (const TypeConverter &typeConverter,
291
+ MLIRContext *context,
292
+ PatternBenefit benefit = 1 )
293
+ : OpConversionPattern(typeConverter, context, benefit) {}
294
+
295
+ LogicalResult
296
+ matchAndRewrite (vector::InsertStridedSliceOp insertStridedSliceOp,
297
+ OpAdaptor adaptor,
298
+ ConversionPatternRewriter &rewriter) const override {
299
+
300
+ // Expect a legalization failure if the strides are not all 1 (if ever the
301
+ // verifier for insert_strided_slice allows non-1 strides).
302
+ if (!stridesAllOne (insertStridedSliceOp)) {
303
+ return rewriter.notifyMatchFailure (
304
+ insertStridedSliceOp,
305
+ " insert_strided_slice with strides != 1 not supported" );
174
306
}
175
- // Final shuffle indices has nExtractedSlices * extractGranularitySize
176
- // elements.
177
- llvm::SmallVector<int64_t , 4 > indices (nExtractedSlices *
178
- extractGranularitySize);
179
- // Compute the strides of the extracted kD vector.
180
- llvm::SmallVector<int64_t , 4 > extractedStrides (kD , 1 );
181
- // Compute extractedStrides.
182
- for (int i = kD - 2 ; i >= 0 ; --i) {
183
- extractedStrides[i] =
184
- extractedStrides[i + 1 ] * cast<IntegerAttr>(sizes[i + 1 ]).getInt ();
307
+
308
+ VectorType inputType = insertStridedSliceOp.getValueToStore ().getType ();
309
+ ArrayRef<int64_t > inputShape = inputType.getShape ();
310
+
311
+ VectorType outputType = insertStridedSliceOp.getType ();
312
+ ArrayRef<int64_t > outputShape = outputType.getShape ();
313
+ int64_t nOutputElements = outputType.getNumElements ();
314
+
315
+ FailureOr<SmallVector<int64_t >> offsets =
316
+ intsFromArrayAttr (insertStridedSliceOp.getOffsets ());
317
+ if (failed (offsets)) {
318
+ return rewriter.notifyMatchFailure (insertStridedSliceOp,
319
+ " failed to get integer offsets" );
185
320
}
186
- // Iterate over all extracted slices from 0 to nExtractedSlices - 1
187
- // and compute the multi-dimensional index and the corresponding linearized
188
- // index within the source vector.
189
- for (int64_t i = 0 ; i < nExtractedSlices; ++i) {
190
- int64_t index = i;
191
- // Compute the corresponding multi-dimensional index.
192
- llvm::SmallVector<int64_t , 4 > multiDimIndex (kD , 0 );
193
- for (int64_t j = 0 ; j < kD ; ++j) {
194
- multiDimIndex[j] = (index / extractedStrides[j]);
195
- index -= multiDimIndex[j] * extractedStrides[j];
196
- }
197
- // Compute the corresponding linearized index in the source vector
198
- // i.e. shift the multiDimIndex by the offsets.
199
- int64_t linearizedIndex = 0 ;
200
- for (int64_t j = 0 ; j < kD ; ++j) {
201
- linearizedIndex +=
202
- (cast<IntegerAttr>(offsets[j]).getInt () + multiDimIndex[j]) *
203
- sourceStrides[j];
204
- }
205
- // Fill the indices array form linearizedIndex to linearizedIndex +
206
- // extractGranularitySize.
207
- for (int64_t j = 0 ; j < extractGranularitySize; ++j) {
208
- indices[i * extractGranularitySize + j] = linearizedIndex + j;
209
- }
321
+ SmallVector<int64_t > sliceIndices = getStridedSliceInsertionIndices (
322
+ inputShape, outputShape, offsets.value ());
323
+
324
+ SmallVector<int64_t > indices (nOutputElements);
325
+ std::iota (indices.begin (), indices.end (), 0 );
326
+ for (auto [index, sliceIndex] : llvm::enumerate (sliceIndices)) {
327
+ indices[sliceIndex] = index + nOutputElements;
210
328
}
211
- // Perform a shuffle to extract the kD vector.
212
- rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
213
- extractOp, dstType, srcVector, srcVector, indices);
329
+
330
+ Value flatToStore = adaptor.getValueToStore ();
331
+ Value flatDest = adaptor.getDest ();
332
+ rewriter.replaceOpWithNewOp <vector::ShuffleOp>(insertStridedSliceOp,
333
+ flatDest.getType (), flatDest,
334
+ flatToStore, indices);
214
335
return success ();
215
336
}
216
337
};
@@ -296,7 +417,7 @@ struct LinearizeVectorExtract final
296
417
// Skip if result is not a vector type
297
418
if (!isa<VectorType>(extractOp.getType ()))
298
419
return rewriter.notifyMatchFailure (extractOp,
299
- " scalar extract is not supported. " );
420
+ " scalar extract not supported" );
300
421
Type dstTy = getTypeConverter ()->convertType (extractOp.getType ());
301
422
assert (dstTy && " expected 1-D vector type" );
302
423
@@ -453,8 +574,8 @@ struct LinearizeVectorSplat final
453
574
static bool isNotLinearizableBecauseScalable (Operation *op) {
454
575
455
576
bool unsupported =
456
- isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
457
- op);
577
+ isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp,
578
+ vector::ExtractOp, vector::InsertOp>( op);
458
579
if (!unsupported)
459
580
return false ;
460
581
@@ -539,6 +660,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
539
660
const TypeConverter &typeConverter, const ConversionTarget &target,
540
661
RewritePatternSet &patterns) {
541
662
patterns.add <LinearizeVectorShuffle, LinearizeVectorExtract,
542
- LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
543
- typeConverter, patterns.getContext ());
663
+ LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
664
+ LinearizeVectorInsertStridedSlice>(typeConverter,
665
+ patterns.getContext ());
544
666
}
0 commit comments