Skip to content

Commit 9768077

Browse files
authored
[mlir][vector] Update helpers in VectorEmulateNarrowType.cpp (nfc) (#131527)
Refactors the following pairs of helper hooks: * `dynamicallyInsertSubVector` + `staticallyInsertSubVector` * `dynamicallyExtractSubVector` + `staticallyExtractSubVector` These hooks are very similar, so I have unified the variable names and various conditions to make the actual differences clearer.
1 parent 3013458 commit 9768077

File tree

1 file changed

+112
-41
lines changed

1 file changed

+112
-41
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 112 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -198,85 +198,156 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
198198
return *newMask;
199199
}
200200

201-
/// Extracts 1-D subvector from a 1-D vector. It is a wrapper function for
202-
/// emitting `vector.extract_strided_slice`.
201+
/// Extracts 1-D subvector from a 1-D vector.
202+
///
203+
/// Given the input rank-1 source vector, extracts `numElemsToExtract` elements
204+
/// from `src`, starting at `offset`. The result is also a rank-1 vector:
205+
///
206+
/// vector<numElemsToExtract x !elemType>
207+
///
208+
/// (`!elType` is the element type of the source vector). As `offset` is a known
209+
/// _static_ value, this helper hook emits `vector.extract_strided_slice`.
210+
///
211+
/// EXAMPLE:
212+
/// %res = vector.extract_strided_slice %src
213+
/// { offsets = [offset], sizes = [numElemsToExtract], strides = [1] }
203214
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
204-
Value source, int64_t frontOffset,
205-
int64_t subvecSize) {
206-
auto vectorType = cast<VectorType>(source.getType());
207-
assert(vectorType.getRank() == 1 && "expected 1-D source types");
208-
assert(frontOffset + subvecSize <= vectorType.getNumElements() &&
215+
Value src, int64_t offset,
216+
int64_t numElemsToExtract) {
217+
auto vectorType = cast<VectorType>(src.getType());
218+
assert(vectorType.getRank() == 1 && "expected source to be rank-1-D vector ");
219+
assert(offset + numElemsToExtract <= vectorType.getNumElements() &&
209220
"subvector out of bounds");
210221

211-
// do not need extraction if the subvector size is the same as the source
212-
if (vectorType.getNumElements() == subvecSize)
213-
return source;
222+
// When extracting all available elements, just use the source vector as the
223+
// result.
224+
if (vectorType.getNumElements() == numElemsToExtract)
225+
return src;
214226

215-
auto offsets = rewriter.getI64ArrayAttr({frontOffset});
216-
auto sizes = rewriter.getI64ArrayAttr({subvecSize});
227+
auto offsets = rewriter.getI64ArrayAttr({offset});
228+
auto sizes = rewriter.getI64ArrayAttr({numElemsToExtract});
217229
auto strides = rewriter.getI64ArrayAttr({1});
218230

219231
auto resultVectorType =
220-
VectorType::get({subvecSize}, vectorType.getElementType());
232+
VectorType::get({numElemsToExtract}, vectorType.getElementType());
221233
return rewriter
222-
.create<vector::ExtractStridedSliceOp>(loc, resultVectorType, source,
234+
.create<vector::ExtractStridedSliceOp>(loc, resultVectorType, src,
223235
offsets, sizes, strides)
224236
->getResult(0);
225237
}
226238

227-
/// Inserts 1-D subvector into a 1-D vector by overwriting the elements starting
228-
/// at `offset`. it is a wrapper function for emitting
239+
/// Inserts 1-D subvector into a 1-D vector.
240+
///
241+
/// Inserts the input rank-1 source vector into the destination vector starting
242+
/// at `offset`. As `offset` is a known _static_ value, this helper hook emits
229243
/// `vector.insert_strided_slice`.
244+
///
245+
/// EXAMPLE:
246+
/// %res = vector.insert_strided_slice %src, %dest
247+
/// {offsets = [%offset], strides [1]}
230248
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
231249
Value src, Value dest, int64_t offset) {
232-
[[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
233-
[[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
234-
assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
235-
"expected source and dest to be vector type");
250+
auto srcVecTy = cast<VectorType>(src.getType());
251+
auto destVecTy = cast<VectorType>(dest.getType());
252+
assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
253+
"expected source and dest to be rank-1 vector types");
254+
255+
// If overwritting the destination vector, just return the source.
256+
if (srcVecTy.getNumElements() == destVecTy.getNumElements() && offset == 0)
257+
return src;
258+
236259
auto offsets = rewriter.getI64ArrayAttr({offset});
237260
auto strides = rewriter.getI64ArrayAttr({1});
238-
return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
261+
return rewriter.create<vector::InsertStridedSliceOp>(loc, destVecTy, src,
239262
dest, offsets, strides);
240263
}
241264

242-
/// Extracts a 1-D subvector from a 1-D `source` vector, with index at `offset`
243-
/// and size `numElementsToExtract`, and inserts into the `dest` vector. This
244-
/// function emits multiple `vector.extract` and `vector.insert` ops, so only
245-
/// use it when `offset` cannot be folded into a constant value.
265+
/// Extracts 1-D subvector from a 1-D vector.
266+
///
267+
/// Given the input rank-1 source vector, extracts `numElemsToExtact` elements
268+
/// from `src`, starting at `offset`. The result is also a rank-1 vector:
269+
///
270+
/// vector<numElemsToExtact x !elType>
271+
///
272+
/// (`!elType` is the element type of the source vector). As `offset` is assumed
273+
/// to be a _dynamic_ SSA value, this helper method generates a sequence of
274+
/// `vector.extract` + `vector.insert` pairs.
275+
///
276+
/// EXAMPLE:
277+
/// %v1 = vector.extract %src[%offset] : i2 from vector<8xi2>
278+
/// %r1 = vector.insert %v1, %dest[0] : i2 into vector<3xi2>
279+
/// %c1 = arith.constant 1 : index
280+
/// %idx2 = arith.addi %offset, %c1 : index
281+
/// %v2 = vector.extract %src[%idx2] : i2 from vector<8xi2>
282+
/// %r2 = vector.insert %v2, %r1 [1] : i2 into vector<3xi2>
283+
/// (...)
246284
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
247-
Value source, Value dest,
285+
Value src, Value dest,
248286
OpFoldResult offset,
249-
int64_t numElementsToExtract) {
250-
assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
251-
for (int i = 0; i < numElementsToExtract; ++i) {
287+
int64_t numElemsToExtract) {
288+
auto srcVecTy = cast<VectorType>(src.getType());
289+
assert(srcVecTy.getRank() == 1 && "expected source to be rank-1-D vector ");
290+
// NOTE: We are unable to take the offset into account in the following
291+
// assert, hence its still possible that the subvector is out-of-bounds even
292+
// if the condition is true.
293+
assert(numElemsToExtract <= srcVecTy.getNumElements() &&
294+
"subvector out of bounds");
295+
296+
// When extracting all available elements, just use the source vector as the
297+
// result.
298+
if (srcVecTy.getNumElements() == numElemsToExtract)
299+
return src;
300+
301+
for (int i = 0; i < numElemsToExtract; ++i) {
252302
Value extractLoc =
253303
(i == 0) ? offset.dyn_cast<Value>()
254304
: rewriter.create<arith::AddIOp>(
255305
loc, rewriter.getIndexType(), offset.dyn_cast<Value>(),
256306
rewriter.create<arith::ConstantIndexOp>(loc, i));
257-
auto extractOp =
258-
rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
307+
auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, extractLoc);
259308
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
260309
}
261310
return dest;
262311
}
263312

264-
/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
313+
/// Inserts 1-D subvector into a 1-D vector.
314+
///
315+
/// Inserts the input rank-1 source vector into the destination vector starting
316+
/// at `offset`. As `offset` is assumed to be a _dynamic_ SSA value, this hook
317+
/// uses a sequence of `vector.extract` + `vector.insert` pairs.
318+
///
319+
/// EXAMPLE:
320+
/// %v1 = vector.extract %src[0] : i2 from vector<8xi2>
321+
/// %r1 = vector.insert %v1, %dest[%offset] : i2 into vector<3xi2>
322+
/// %c1 = arith.constant 1 : index
323+
/// %idx2 = arith.addi %offset, %c1 : index
324+
/// %v2 = vector.extract %src[1] : i2 from vector<8xi2>
325+
/// %r2 = vector.insert %v2, %r1 [%idx2] : i2 into vector<3xi2>
326+
/// (...)
265327
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
266-
Value source, Value dest,
267-
OpFoldResult destOffsetVar,
268-
size_t length) {
269-
assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
270-
assert(length > 0 && "length must be greater than 0");
271-
Value destOffsetVal =
272-
getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
273-
for (size_t i = 0; i < length; ++i) {
328+
Value src, Value dest,
329+
OpFoldResult offset,
330+
int64_t numElemsToInsert) {
331+
auto srcVecTy = cast<VectorType>(src.getType());
332+
auto destVecTy = cast<VectorType>(dest.getType());
333+
assert(srcVecTy.getRank() == 1 && destVecTy.getRank() == 1 &&
334+
"expected source and dest to be rank-1 vector types");
335+
assert(numElemsToInsert > 0 &&
336+
"the number of elements to insert must be greater than 0");
337+
// NOTE: We are unable to take the offset into account in the following
338+
// assert, hence its still possible that the subvector is out-of-bounds even
339+
// if the condition is true.
340+
assert(numElemsToInsert <= destVecTy.getNumElements() &&
341+
"subvector out of bounds");
342+
343+
Value destOffsetVal = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
344+
for (int64_t i = 0; i < numElemsToInsert; ++i) {
274345
auto insertLoc = i == 0
275346
? destOffsetVal
276347
: rewriter.create<arith::AddIOp>(
277348
loc, rewriter.getIndexType(), destOffsetVal,
278349
rewriter.create<arith::ConstantIndexOp>(loc, i));
279-
auto extractOp = rewriter.create<vector::ExtractOp>(loc, source, i);
350+
auto extractOp = rewriter.create<vector::ExtractOp>(loc, src, i);
280351
dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, insertLoc);
281352
}
282353
return dest;

0 commit comments

Comments
 (0)