Skip to content

Commit d458e3b

Browse files
committed
Another update to address review comments
1 parent 50f0786 commit d458e3b

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 15 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -240,9 +240,10 @@ static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
240240
/// function emits multiple `vector.extract` and `vector.insert` ops, so only
241241
/// use it when `offset` cannot be folded into a constant value.
242242
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
243-
VectorValue source, Value dest,
243+
Value source, Value dest,
244244
OpFoldResult offset,
245245
int64_t numElementsToExtract) {
246+
assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
246247
for (int i = 0; i < numElementsToExtract; ++i) {
247248
Value extractLoc =
248249
(i == 0) ? offset.dyn_cast<Value>()
@@ -258,9 +259,10 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
258259

259260
/// Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
260261
static Value dynamicallyInsertSubVector(RewriterBase &rewriter, Location loc,
261-
VectorValue source, Value dest,
262+
Value source, Value dest,
262263
OpFoldResult destOffsetVar,
263264
size_t length) {
265+
assert(isa<VectorValue>(source) && "expected `source` to be a vector type");
264266
assert(length > 0 && "length must be greater than 0");
265267
Value destOffsetVal =
266268
getValueOrCreateConstantIndexOp(rewriter, loc, destOffsetVar);
@@ -468,7 +470,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
468470

469471
auto memrefBase = cast<MemRefValue>(adaptor.getBase());
470472

471-
// Conditions when subbyte emulated store is not needed:
473+
// Conditions when atomic RMWs are not needed:
472474
// 1. The source vector size (in bits) is a multiple of byte size.
473475
// 2. The address of the store is aligned to the emulated width boundary.
474476
//
@@ -499,7 +501,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
499501
// Destination: memref<12xi2>
500502
// Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
501503
//
502-
// MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
504+
// Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
503505
//
504506
// Destination memref before:
505507
//
@@ -817,9 +819,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
817819
if (!foldedIntraVectorOffset) {
818820
auto resultVector = rewriter.create<arith::ConstantOp>(
819821
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
820-
result = dynamicallyExtractSubVector(
821-
rewriter, loc, cast<VectorValue>(result), resultVector,
822-
linearizedInfo.intraDataOffset, origElements);
822+
result = dynamicallyExtractSubVector(rewriter, loc, result, resultVector,
823+
linearizedInfo.intraDataOffset,
824+
origElements);
823825
} else if (!isAlignedEmulation) {
824826
result = staticallyExtractSubvector(
825827
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
@@ -938,8 +940,8 @@ struct ConvertVectorMaskedLoad final
938940
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
939941
if (!foldedIntraVectorOffset) {
940942
passthru = dynamicallyInsertSubVector(
941-
rewriter, loc, cast<VectorValue>(passthru), emptyVector,
942-
linearizedInfo.intraDataOffset, origElements);
943+
rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
944+
origElements);
943945
} else if (!isAlignedEmulation) {
944946
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
945947
*foldedIntraVectorOffset);
@@ -965,9 +967,9 @@ struct ConvertVectorMaskedLoad final
965967
auto emptyMask = rewriter.create<arith::ConstantOp>(
966968
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
967969
if (!foldedIntraVectorOffset) {
968-
mask = dynamicallyInsertSubVector(
969-
rewriter, loc, cast<VectorValue>(mask), emptyMask,
970-
linearizedInfo.intraDataOffset, origElements);
970+
mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
971+
linearizedInfo.intraDataOffset,
972+
origElements);
971973
} else if (!isAlignedEmulation) {
972974
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
973975
*foldedIntraVectorOffset);
@@ -977,7 +979,7 @@ struct ConvertVectorMaskedLoad final
977979
rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);
978980
if (!foldedIntraVectorOffset) {
979981
result = dynamicallyExtractSubVector(
980-
rewriter, loc, cast<VectorValue>(result), op.getPassThru(),
982+
rewriter, loc, result, op.getPassThru(),
981983
linearizedInfo.intraDataOffset, origElements);
982984
} else if (!isAlignedEmulation) {
983985
result = staticallyExtractSubvector(

0 commit comments

Comments
 (0)