@@ -198,85 +198,156 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
198
198
return *newMask;
199
199
}
200
200
201
- // / Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
202
- // / emitting `vector.extract_strided_slice`.
201
+ // / Extracts 1-D subvector from a 1-D vector.
202
+ // /
203
+ // / Given the input rank-1 source vector, extracts `numElemsToExtract` elements
204
+ // / from `src`, starting at `offset`. The result is also a rank-1 vector:
205
+ // /
206
+ // / vector<numElemsToExtract x !elType>
207
+ // /
208
+ // / (`!elType` is the element type of the source vector). As `offset` is a known
209
+ // / _static_ value, this helper hook emits `vector.extract_strided_slice`.
210
+ // /
211
+ // / EXAMPLE:
212
+ // / %res = vector.extract_strided_slice %src
213
+ // / { offsets = [offset], sizes = [numElemsToExtract], strides = [1] }
203
214
static Value staticallyExtractSubvector (OpBuilder &rewriter, Location loc,
204
- Value source , int64_t frontOffset ,
205
- int64_t subvecSize ) {
206
- auto vectorType = cast<VectorType>(source .getType ());
207
- assert (vectorType.getRank () == 1 && " expected 1-D source types " );
208
- assert (frontOffset + subvecSize <= vectorType.getNumElements () &&
215
+ Value src , int64_t offset ,
216
+ int64_t numElemsToExtract ) {
217
+ auto vectorType = cast<VectorType>(src .getType ());
218
+ assert (vectorType.getRank () == 1 && " expected source to be rank- 1-D vector " );
219
+ assert (offset + numElemsToExtract <= vectorType.getNumElements () &&
209
220
" subvector out of bounds" );
210
221
211
- // do not need extraction if the subvector size is the same as the source
212
- if (vectorType.getNumElements () == subvecSize)
213
- return source;
222
+ // When extracting all available elements, just use the source vector as the
223
+ // result (no `vector.extract_strided_slice` op is emitted in that case).
224
+ if (vectorType.getNumElements () == numElemsToExtract)
225
+ return src;
214
226
215
- auto offsets = rewriter.getI64ArrayAttr ({frontOffset });
216
- auto sizes = rewriter.getI64ArrayAttr ({subvecSize });
227
+ auto offsets = rewriter.getI64ArrayAttr ({offset });
228
+ auto sizes = rewriter.getI64ArrayAttr ({numElemsToExtract });
217
229
auto strides = rewriter.getI64ArrayAttr ({1 });
218
230
219
231
auto resultVectorType =
220
- VectorType::get ({subvecSize }, vectorType.getElementType ());
232
+ VectorType::get ({numElemsToExtract }, vectorType.getElementType ());
221
233
return rewriter
222
- .create <vector::ExtractStridedSliceOp>(loc, resultVectorType, source ,
234
+ .create <vector::ExtractStridedSliceOp>(loc, resultVectorType, src ,
223
235
offsets, sizes, strides)
224
236
->getResult (0 );
225
237
}
226
238
227
- // / Inserts 1-D subvector into a 1-D vector by overwriting the elements starting
228
- // / at `offset`. it is a wrapper function for emitting
239
+ // / Inserts 1-D subvector into a 1-D vector.
240
+ // /
241
+ // / Inserts the input rank-1 source vector into the destination vector starting
242
+ // / at `offset`. As `offset` is a known _static_ value, this helper hook emits
229
243
// / `vector.insert_strided_slice`.
244
+ // /
245
+ // / EXAMPLE:
246
+ // / %res = vector.insert_strided_slice %src, %dest
247
+ // / { offsets = [offset], strides = [1] }
230
248
static Value staticallyInsertSubvector (OpBuilder &rewriter, Location loc,
231
249
Value src, Value dest, int64_t offset) {
232
- [[maybe_unused]] auto srcType = cast<VectorType>(src.getType ());
233
- [[maybe_unused]] auto destType = cast<VectorType>(dest.getType ());
234
- assert (srcType.getRank () == 1 && destType.getRank () == 1 &&
235
- " expected source and dest to be vector type" );
250
+ auto srcVecTy = cast<VectorType>(src.getType ());
251
+ auto destVecTy = cast<VectorType>(dest.getType ());
252
+ assert (srcVecTy.getRank () == 1 && destVecTy.getRank () == 1 &&
253
+ " expected source and dest to be rank-1 vector types" );
254
+
255
+ // If overwriting the whole destination vector (same number of elements and offset 0), just return the source.
256
+ if (srcVecTy.getNumElements () == destVecTy.getNumElements () && offset == 0 )
257
+ return src;
258
+
236
259
auto offsets = rewriter.getI64ArrayAttr ({offset});
237
260
auto strides = rewriter.getI64ArrayAttr ({1 });
238
- return rewriter.create <vector::InsertStridedSliceOp>(loc, dest. getType () , src,
261
+ return rewriter.create <vector::InsertStridedSliceOp>(loc, destVecTy , src,
239
262
dest, offsets, strides);
240
263
}
241
264
242
- // / Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
243
- // / and size `numElementsToExtract`, and inserts into the `dest` vector. This
244
- // / function emits multiple `vector.extract` and `vector.insert` ops, so only
245
- // / use it when `offset` cannot be folded into a constant value.
265
+ // / Extracts 1-D subvector from a 1-D vector.
266
+ // /
267
+ // / Given the input rank-1 source vector, extracts `numElemsToExtract` elements
268
+ // / from `src`, starting at `offset`. The result is also a rank-1 vector:
269
+ // /
270
+ // / vector<numElemsToExtract x !elType>
271
+ // /
272
+ // / (`!elType` is the element type of the source vector). As `offset` is assumed
273
+ // / to be a _dynamic_ SSA value, this helper method generates a sequence of
274
+ // / `vector.extract` + `vector.insert` pairs.
275
+ // /
276
+ // / EXAMPLE:
277
+ // / %v1 = vector.extract %src[%offset] : i2 from vector<8xi2>
278
+ // / %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2>
279
+ // / %c1 = arith.constant 1 : index
280
+ // / %idx2 = arith.addi %offset, %c1 : index
281
+ // / %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2>
282
+ // / %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2>
283
+ // / (...)
246
284
static Value dynamicallyExtractSubVector (OpBuilder &rewriter, Location loc,
247
- Value source , Value dest,
285
+ Value src , Value dest,
248
286
OpFoldResult offset,
249
- int64_t numElementsToExtract) {
250
- assert (isa<VectorValue>(source) && " expected `source` to be a vector type" );
251
- for (int i = 0 ; i < numElementsToExtract; ++i) {
287
+ int64_t numElemsToExtract) {
288
+ auto srcVecTy = cast<VectorType>(src.getType ());
289
+ assert (srcVecTy.getRank () == 1 && " expected source to be rank-1-D vector " );
290
+ // NOTE: We are unable to take the offset into account in the following
291
+ // assert, hence it's still possible that the subvector is out-of-bounds even
292
+ // if the condition is true.
293
+ assert (numElemsToExtract <= srcVecTy.getNumElements () &&
294
+ " subvector out of bounds" );
295
+
296
+ // When extracting all available elements, just use the source vector as the
297
+ // result; `dest` and `offset` are not used on this path.
298
+ if (srcVecTy.getNumElements () == numElemsToExtract)
299
+ return src;
300
+
301
+ for (int i = 0 ; i < numElemsToExtract; ++i) {
252
302
Value extractLoc =
253
303
(i == 0 ) ? offset.dyn_cast <Value>() // NOTE(review): presumably a null Value if `offset` holds an Attribute; confirm callers always pass an SSA value
254
304
: rewriter.create <arith::AddIOp>(
255
305
loc, rewriter.getIndexType (), offset.dyn_cast <Value>(),
256
306
rewriter.create <arith::ConstantIndexOp>(loc, i));
257
- auto extractOp =
258
- rewriter.create <vector::ExtractOp>(loc, source, extractLoc);
307
+ auto extractOp = rewriter.create <vector::ExtractOp>(loc, src, extractLoc);
259
308
dest = rewriter.create <vector::InsertOp>(loc, extractOp, dest, i);
260
309
}
261
310
return dest;
262
311
}
263
312
264
- // / Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
313
+ // / Inserts 1-D subvector into a 1-D vector.
314
+ // /
315
+ // / Inserts the input rank-1 source vector into the destination vector starting
316
+ // / at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook
317
+ // / uses a sequence of `vector.extract` + `vector.insert` pairs.
318
+ // /
319
+ // / EXAMPLE:
320
+ // / %v1 = vector.extract %src[0] : i2 from vector<8xi2>
321
+ // / %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2>
322
+ // / %c1 = arith.constant 1 : index
323
+ // / %idx2 = arith.addi %offset, %c1 : index
324
+ // / %v2 = vector.extract %src[1] : i2 from vector<8xi2>
325
+ // / %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2>
326
+ // / (...)
265
327
static Value dynamicallyInsertSubVector (RewriterBase &rewriter, Location loc,
266
- Value source, Value dest,
267
- OpFoldResult destOffsetVar,
268
- size_t length) {
269
- assert (isa<VectorValue>(source) && " expected `source` to be a vector type" );
270
- assert (length > 0 && " length must be greater than 0" );
271
- Value destOffsetVal =
272
- getValueOrCreateConstantIndexOp (rewriter, loc, destOffsetVar);
273
- for (size_t i = 0 ; i < length; ++i) {
328
+ Value src, Value dest,
329
+ OpFoldResult offset,
330
+ int64_t numElemsToInsert) {
331
+ auto srcVecTy = cast<VectorType>(src.getType ());
332
+ auto destVecTy = cast<VectorType>(dest.getType ());
333
+ assert (srcVecTy.getRank () == 1 && destVecTy.getRank () == 1 &&
334
+ " expected source and dest to be rank-1 vector types" );
335
+ assert (numElemsToInsert > 0 &&
336
+ " the number of elements to insert must be greater than 0" );
337
+ // NOTE: We are unable to take the offset into account in the following
338
+ // assert, hence it's still possible that the subvector is out-of-bounds even
339
+ // if the condition is true.
340
+ assert (numElemsToInsert <= destVecTy.getNumElements () &&
341
+ " subvector out of bounds" );
342
+
343
+ Value destOffsetVal = getValueOrCreateConstantIndexOp (rewriter, loc, offset);
344
+ for (int64_t i = 0 ; i < numElemsToInsert; ++i) {
274
345
auto insertLoc = i == 0
275
346
? destOffsetVal
276
347
: rewriter.create <arith::AddIOp>(
277
348
loc, rewriter.getIndexType (), destOffsetVal,
278
349
rewriter.create <arith::ConstantIndexOp>(loc, i));
279
- auto extractOp = rewriter.create <vector::ExtractOp>(loc, source , i);
350
+ auto extractOp = rewriter.create <vector::ExtractOp>(loc, src , i);
280
351
dest = rewriter.create <vector::InsertOp>(loc, extractOp, dest, insertLoc);
281
352
}
282
353
return dest;
0 commit comments