@@ -305,12 +305,14 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
305
305
assert (
306
306
downcastType.getNumElements () * downcastType.getElementTypeBitWidth () ==
307
307
upcastType.getNumElements () * upcastType.getElementTypeBitWidth () &&
308
- " expected upcastType size to be twice the size of downcastType " );
309
- if (trueValue.getType () != downcastType)
308
+ " expected input and output number of bits to match " );
309
+ if (trueValue.getType () != downcastType) {
310
310
trueValue = builder.create <vector::BitCastOp>(loc, downcastType, trueValue);
311
- if (falseValue.getType () != downcastType)
311
+ }
312
+ if (falseValue.getType () != downcastType) {
312
313
falseValue =
313
314
builder.create <vector::BitCastOp>(loc, downcastType, falseValue);
315
+ }
314
316
Value selectedType =
315
317
builder.create <arith::SelectOp>(loc, mask, trueValue, falseValue);
316
318
// Upcast the selected value to the new type.
@@ -454,28 +456,33 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
454
456
stridedMetadata.getConstifiedMixedStrides (),
455
457
getAsOpFoldResult (adaptor.getIndices ()));
456
458
457
- auto foldedNumFrontPadElems =
459
+ std::optional< int64_t > foldedNumFrontPadElems =
458
460
isUnalignedEmulation
459
461
? getConstantIntValue (linearizedInfo.intraDataOffset )
460
462
: 0 ;
461
463
462
464
if (!foldedNumFrontPadElems) {
463
- // Unimplemented case for dynamic front padding size != 0
464
- return failure ( );
465
+ return failure ( " subbyte store emulation: dynamic front padding size is "
466
+ " not yet implemented " );
465
467
}
466
468
467
- auto linearizedMemref = cast<MemRefValue>(adaptor.getBase ());
469
+ auto memrefBase = cast<MemRefValue>(adaptor.getBase ());
468
470
469
- // Shortcut: conditions when subbyte store at the front is not needed:
471
+ // Shortcut: conditions when subbyte emulated store at the front is not
472
+ // needed:
470
473
// 1. The source vector size is multiple of byte size
471
474
// 2. The address of the store is aligned to the emulated width boundary
475
+ //
476
+ // For example, to store a vector<4xi2> to <13xi2> at offset 4, does not
477
+ // need unaligned emulation because the store address is aligned and the
478
+ // source is a whole byte.
472
479
if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0 ) {
473
480
auto numElements = origElements / numSrcElemsPerDest;
474
481
auto bitCast = rewriter.create <vector::BitCastOp>(
475
482
loc, VectorType::get (numElements, newElementType),
476
483
op.getValueToStore ());
477
484
rewriter.replaceOpWithNewOp <vector::StoreOp>(
478
- op, bitCast.getResult (), linearizedMemref ,
485
+ op, bitCast.getResult (), memrefBase ,
479
486
getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices));
480
487
return success ();
481
488
}
@@ -511,7 +518,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
511
518
extractSliceIntoByte (rewriter, loc, valueToStore, 0 ,
512
519
frontSubWidthStoreElem, *foldedNumFrontPadElems);
513
520
514
- atomicStore (rewriter, loc, linearizedMemref , currentDestIndex,
521
+ atomicStore (rewriter, loc, memrefBase , currentDestIndex,
515
522
cast<VectorValue>(value), frontMask.getResult ());
516
523
}
517
524
@@ -537,13 +544,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
537
544
numNonFullWidthElements);
538
545
539
546
auto originType = cast<VectorType>(fullWidthStorePart.getType ());
540
- auto memrefElemType = getElementTypeOrSelf (linearizedMemref .getType ());
547
+ auto memrefElemType = getElementTypeOrSelf (memrefBase .getType ());
541
548
auto storeType = VectorType::get (
542
549
{originType.getNumElements () / numSrcElemsPerDest}, memrefElemType);
543
550
auto bitCast = rewriter.create <vector::BitCastOp>(loc, storeType,
544
551
fullWidthStorePart);
545
- rewriter.create <vector::StoreOp>(loc, bitCast.getResult (),
546
- linearizedMemref, currentDestIndex);
552
+ rewriter.create <vector::StoreOp>(loc, bitCast.getResult (), memrefBase,
553
+ currentDestIndex);
547
554
548
555
currentSourceIndex += numNonFullWidthElements;
549
556
currentDestIndex = rewriter.create <arith::AddIOp>(
@@ -565,7 +572,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
565
572
auto backMask = rewriter.create <arith::ConstantOp>(
566
573
loc, DenseElementsAttr::get (subWidthStoreMaskType, maskValues));
567
574
568
- atomicStore (rewriter, loc, linearizedMemref , currentDestIndex,
575
+ atomicStore (rewriter, loc, memrefBase , currentDestIndex,
569
576
cast<VectorValue>(subWidthStorePart), backMask.getResult ());
570
577
}
571
578
0 commit comments