@@ -240,9 +240,10 @@ static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
240
240
// / function emits multiple `vector.extract` and `vector.insert` ops, so only
241
241
// / use it when `offset` cannot be folded into a constant value.
242
242
static Value dynamicallyExtractSubVector (OpBuilder &rewriter, Location loc,
243
- VectorValue source, Value dest,
243
+ Value source, Value dest,
244
244
OpFoldResult offset,
245
245
int64_t numElementsToExtract) {
246
+ assert (isa<VectorValue>(source) && " expected `source` to be a vector type" );
246
247
for (int i = 0 ; i < numElementsToExtract; ++i) {
247
248
Value extractLoc =
248
249
(i == 0 ) ? offset.dyn_cast <Value>()
@@ -258,9 +259,10 @@ static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
258
259
259
260
// / Inserts a 1-D subvector into a 1-D `dest` vector at index `destOffsetVar`.
260
261
static Value dynamicallyInsertSubVector (RewriterBase &rewriter, Location loc,
261
- VectorValue source, Value dest,
262
+ Value source, Value dest,
262
263
OpFoldResult destOffsetVar,
263
264
size_t length) {
265
+ assert (isa<VectorValue>(source) && " expected `source` to be a vector type" );
264
266
assert (length > 0 && " length must be greater than 0" );
265
267
Value destOffsetVal =
266
268
getValueOrCreateConstantIndexOp (rewriter, loc, destOffsetVar);
@@ -468,7 +470,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
468
470
469
471
auto memrefBase = cast<MemRefValue>(adaptor.getBase ());
470
472
471
- // Conditions when subbyte emulated store is not needed:
473
+ // Conditions when atomic RMWs are not needed:
472
474
// 1. The source vector size (in bits) is a multiple of byte size.
473
475
// 2. The address of the store is aligned to the emulated width boundary.
474
476
//
@@ -499,7 +501,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
499
501
// Destination: memref<12xi2>
500
502
// Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
501
503
//
502
- // MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
504
+ // Input MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
503
505
//
504
506
// Destination memref before:
505
507
//
@@ -817,9 +819,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
817
819
if (!foldedIntraVectorOffset) {
818
820
auto resultVector = rewriter.create <arith::ConstantOp>(
819
821
loc, op.getType (), rewriter.getZeroAttr (op.getType ()));
820
- result = dynamicallyExtractSubVector (
821
- rewriter, loc, cast<VectorValue>(result), resultVector ,
822
- linearizedInfo. intraDataOffset , origElements);
822
+ result = dynamicallyExtractSubVector (rewriter, loc, result, resultVector,
823
+ linearizedInfo. intraDataOffset ,
824
+ origElements);
823
825
} else if (!isAlignedEmulation) {
824
826
result = staticallyExtractSubvector (
825
827
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
@@ -938,8 +940,8 @@ struct ConvertVectorMaskedLoad final
938
940
loc, newBitcastType, rewriter.getZeroAttr (newBitcastType));
939
941
if (!foldedIntraVectorOffset) {
940
942
passthru = dynamicallyInsertSubVector (
941
- rewriter, loc, cast<VectorValue>( passthru) , emptyVector,
942
- linearizedInfo. intraDataOffset , origElements);
943
+ rewriter, loc, passthru, emptyVector, linearizedInfo. intraDataOffset ,
944
+ origElements);
943
945
} else if (!isAlignedEmulation) {
944
946
passthru = staticallyInsertSubvector (rewriter, loc, passthru, emptyVector,
945
947
*foldedIntraVectorOffset);
@@ -965,9 +967,9 @@ struct ConvertVectorMaskedLoad final
965
967
auto emptyMask = rewriter.create <arith::ConstantOp>(
966
968
loc, newSelectMaskType, rewriter.getZeroAttr (newSelectMaskType));
967
969
if (!foldedIntraVectorOffset) {
968
- mask = dynamicallyInsertSubVector (
969
- rewriter, loc, cast<VectorValue>(mask), emptyMask ,
970
- linearizedInfo. intraDataOffset , origElements);
970
+ mask = dynamicallyInsertSubVector (rewriter, loc, mask, emptyMask,
971
+ linearizedInfo. intraDataOffset ,
972
+ origElements);
971
973
} else if (!isAlignedEmulation) {
972
974
mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyMask,
973
975
*foldedIntraVectorOffset);
@@ -977,7 +979,7 @@ struct ConvertVectorMaskedLoad final
977
979
rewriter.create <arith::SelectOp>(loc, mask, bitCast, passthru);
978
980
if (!foldedIntraVectorOffset) {
979
981
result = dynamicallyExtractSubVector (
980
- rewriter, loc, cast<VectorValue>( result) , op.getPassThru (),
982
+ rewriter, loc, result, op.getPassThru (),
981
983
linearizedInfo.intraDataOffset , origElements);
982
984
} else if (!isAlignedEmulation) {
983
985
result = staticallyExtractSubvector (
0 commit comments