Skip to content

Commit 974c2cb

Browse files
committed
Address comments
1 parent 70de874 commit 974c2cb

File tree

1 file changed

+21
-14
lines changed

1 file changed

+21
-14
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -305,12 +305,14 @@ static Value downcastSelectAndUpcast(OpBuilder &builder, Location loc,
305305
assert(
306306
downcastType.getNumElements() * downcastType.getElementTypeBitWidth() ==
307307
upcastType.getNumElements() * upcastType.getElementTypeBitWidth() &&
308-
"expected upcastType size to be twice the size of downcastType");
309-
if (trueValue.getType() != downcastType)
308+
"expected input and output number of bits to match");
309+
if (trueValue.getType() != downcastType) {
310310
trueValue = builder.create<vector::BitCastOp>(loc, downcastType, trueValue);
311-
if (falseValue.getType() != downcastType)
311+
}
312+
if (falseValue.getType() != downcastType) {
312313
falseValue =
313314
builder.create<vector::BitCastOp>(loc, downcastType, falseValue);
315+
}
314316
Value selectedType =
315317
builder.create<arith::SelectOp>(loc, mask, trueValue, falseValue);
316318
// Upcast the selected value to the new type.
@@ -454,28 +456,33 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
454456
stridedMetadata.getConstifiedMixedStrides(),
455457
getAsOpFoldResult(adaptor.getIndices()));
456458

457-
auto foldedNumFrontPadElems =
459+
std::optional<int64_t> foldedNumFrontPadElems =
458460
isUnalignedEmulation
459461
? getConstantIntValue(linearizedInfo.intraDataOffset)
460462
: 0;
461463

462464
if (!foldedNumFrontPadElems) {
463-
// Unimplemented case for dynamic front padding size != 0
464-
return failure();
465+
return failure("subbyte store emulation: dynamic front padding size is "
466+
"not yet implemented");
465467
}
466468

467-
auto linearizedMemref = cast<MemRefValue>(adaptor.getBase());
469+
auto memrefBase = cast<MemRefValue>(adaptor.getBase());
468470

469-
// Shortcut: conditions when subbyte store at the front is not needed:
471+
// Shortcut: conditions when subbyte emulated store at the front is not
472+
// needed:
470473
// 1. The source vector size is a multiple of the byte size
471474
// 2. The address of the store is aligned to the emulated width boundary
475+
//
476+
// For example, storing a vector<4xi2> to <13xi2> at offset 4 does not
477+
// need unaligned emulation because the store address is aligned and the
478+
// source is a whole byte.
472479
if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
473480
auto numElements = origElements / numSrcElemsPerDest;
474481
auto bitCast = rewriter.create<vector::BitCastOp>(
475482
loc, VectorType::get(numElements, newElementType),
476483
op.getValueToStore());
477484
rewriter.replaceOpWithNewOp<vector::StoreOp>(
478-
op, bitCast.getResult(), linearizedMemref,
485+
op, bitCast.getResult(), memrefBase,
479486
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
480487
return success();
481488
}
@@ -511,7 +518,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
511518
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
512519
frontSubWidthStoreElem, *foldedNumFrontPadElems);
513520

514-
atomicStore(rewriter, loc, linearizedMemref, currentDestIndex,
521+
atomicStore(rewriter, loc, memrefBase, currentDestIndex,
515522
cast<VectorValue>(value), frontMask.getResult());
516523
}
517524

@@ -537,13 +544,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
537544
numNonFullWidthElements);
538545

539546
auto originType = cast<VectorType>(fullWidthStorePart.getType());
540-
auto memrefElemType = getElementTypeOrSelf(linearizedMemref.getType());
547+
auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
541548
auto storeType = VectorType::get(
542549
{originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
543550
auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
544551
fullWidthStorePart);
545-
rewriter.create<vector::StoreOp>(loc, bitCast.getResult(),
546-
linearizedMemref, currentDestIndex);
552+
rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
553+
currentDestIndex);
547554

548555
currentSourceIndex += numNonFullWidthElements;
549556
currentDestIndex = rewriter.create<arith::AddIOp>(
@@ -565,7 +572,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
565572
auto backMask = rewriter.create<arith::ConstantOp>(
566573
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
567574

568-
atomicStore(rewriter, loc, linearizedMemref, currentDestIndex,
575+
atomicStore(rewriter, loc, memrefBase, currentDestIndex,
569576
cast<VectorValue>(subWidthStorePart), backMask.getResult());
570577
}
571578

0 commit comments

Comments
 (0)