
[mlir][Vector] Update VectorEmulateNarrowType.cpp (1/N) #123526


Merged
120 changes: 66 additions & 54 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -415,18 +415,21 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
"only 1-D vectors are supported ATM");

auto loc = op.getLoc();

auto valueToStore = cast<VectorValue>(op.getValueToStore());
auto oldElementType = valueToStore.getType().getElementType();
auto newElementType =
auto containerElemTy =
cast<MemRefType>(adaptor.getBase().getType()).getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();
Type emulatedElemTy = op.getValueToStore().getType().getElementType();
int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
int containerBits = containerElemTy.getIntOrFloatBitWidth();

if (dstBits % srcBits != 0) {
// Check per-element alignment.
if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
op, "impossible to pack emulated elements into container elements "
"(bit-wise misalignment)");
}
int numSrcElemsPerDest = dstBits / srcBits;
int numSrcElemsPerDest = containerBits / emulatedBits;

// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -451,7 +454,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -483,7 +486,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// Basic case: storing full bytes.
auto numElements = origElements / numSrcElemsPerDest;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, newElementType),
loc, VectorType::get(numElements, containerElemTy),
op.getValueToStore());
rewriter.replaceOpWithNewOp<vector::StoreOp>(
op, bitCast.getResult(), memrefBase,
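The "basic case" above stores whole container bytes: the value is vector.bitcast from numElements * numSrcElemsPerDest emulated elements down to numElements container elements and stored as-is. A standalone C++ sketch of the equivalent repacking for i4-in-i8, assuming the low-nibble-first layout the emulation targets (helper name hypothetical):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: pack two i4 elements into each i8 container byte,
// element 0 in the low nibble (assumed layout; mirrors the vector.bitcast).
std::vector<uint8_t> packI4(const std::vector<uint8_t> &elems) {
  assert(elems.size() % 2 == 0 && "full bytes only -- the basic case above");
  std::vector<uint8_t> packed(elems.size() / 2);
  for (size_t i = 0; i < elems.size(); ++i)
    packed[i / 2] |= (elems[i] & 0xF) << ((i % 2) * 4);
  return packed;
}

int main() {
  // vector<8xi4> -> vector<4xi8>: numSrcElemsPerDest == 8 / 4 == 2.
  auto packed = packI4({1, 2, 3, 4, 5, 6, 7, 8});
  assert(packed.size() == 4);
  assert(packed[0] == 0x21); // low nibble = element 0, high nibble = element 1
}
```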
@@ -638,18 +641,20 @@ struct ConvertVectorMaskedStore final
"only 1-D vectors are supported ATM");

auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getValueToStore().getType().getElementType();
Type newElementType = convertedType.getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();
auto containerElemTy =
cast<MemRefType>(adaptor.getBase().getType()).getElementType();
Type emulatedElemTy = op.getValueToStore().getType().getElementType();
int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
int containerBits = containerElemTy.getIntOrFloatBitWidth();

if (dstBits % srcBits != 0) {
// Check per-element alignment.
if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
op, "impossible to pack emulated elements into container elements "
"(bit-wise misalignment)");
}

int scale = dstBits / srcBits;
int scale = containerBits / emulatedBits;
int origElements = op.getValueToStore().getType().getNumElements();
if (origElements % scale != 0)
return failure();
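The renamed guards are easiest to read with concrete widths. A minimal sketch of the two checks above, using assumed example values:

```cpp
#include <cassert>

// The emulated width must divide the container width, and -- for the masked
// store -- the element count must fill whole container elements.
bool supported(int containerBits, int emulatedBits, int origElements) {
  if (containerBits % emulatedBits != 0)
    return false;                    // bit-wise misalignment
  int scale = containerBits / emulatedBits;
  return origElements % scale == 0;  // whole containers only
}

int main() {
  assert(supported(8, 4, 6));   // vector<6xi4> into i8: exactly three bytes
  assert(!supported(8, 3, 6));  // i3 into i8: bit-wise misaligned
  assert(!supported(8, 4, 5));  // vector<5xi4>: half a byte left over
}
```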
@@ -660,7 +665,7 @@ struct ConvertVectorMaskedStore final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndicesOfr) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -706,15 +711,15 @@ struct ConvertVectorMaskedStore final
return failure();

auto numElements = (origElements + scale - 1) / scale;
auto newType = VectorType::get(numElements, newElementType);
auto newType = VectorType::get(numElements, containerElemTy);
auto passThru = rewriter.create<arith::ConstantOp>(
loc, newType, rewriter.getZeroAttr(newType));

auto newLoad = rewriter.create<vector::MaskedLoadOp>(
loc, newType, adaptor.getBase(), linearizedIndices,
newMask.value()->getResult(0), passThru);

auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
auto newBitCastType = VectorType::get(numElements * scale, emulatedElemTy);
Value valueToStore =
rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
valueToStore = rewriter.create<arith::SelectOp>(
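Because several emulated elements share one container element, a masked store must not clobber the unmasked neighbours; the hunk above therefore masked-loads the containers, arith.selects per emulated element, and stores the merged value back. A scalar C++ analogue of that read-modify-write, assuming i4-in-i8 with low-nibble-first layout (function name hypothetical):

```cpp
#include <cassert>
#include <cstdint>

// Scalar analogue of the masked-store read-modify-write: load the container
// byte, replace only the masked i4, write the byte back.
void maskedStoreI4(uint8_t *mem, int idx, uint8_t value, bool mask) {
  uint8_t byte = mem[idx / 2];                // vector.maskedload
  int shift = (idx % 2) * 4;                  // nibble position (assumed)
  if (mask)                                   // arith.select, per element
    byte = (byte & ~(0xFu << shift)) | ((value & 0xFu) << shift);
  mem[idx / 2] = byte;                        // store the container back
}

int main() {
  uint8_t mem[1] = {0x21};
  maskedStoreI4(mem, 1, 0x7, /*mask=*/true);  // overwrite the high nibble
  assert(mem[0] == 0x71);
  maskedStoreI4(mem, 0, 0xF, /*mask=*/false); // masked off: byte unchanged
  assert(mem[0] == 0x71);
}
```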
@@ -746,17 +751,19 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
"only 1-D vectors are supported ATM");

auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();
auto containerElemTy =
cast<MemRefType>(adaptor.getBase().getType()).getElementType();
Type emulatedElemTy = op.getType().getElementType();
int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
int containerBits = containerElemTy.getIntOrFloatBitWidth();

if (dstBits % srcBits != 0) {
// Check per-element alignment.
if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
op, "impossible to pack emulated elements into container elements "
"(bit-wise misalignment)");
}
int scale = dstBits / srcBits;
int scale = containerBits / emulatedBits;

// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
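As the comment notes, an unaligned emulated load can straddle container boundaries, so the pattern loads divideCeil(intra-offset + origElements, scale) container elements and extracts the requested slice after the bitcast. The sizing arithmetic, sketched with assumed values (divideCeil re-implemented locally to stay self-contained):

```cpp
#include <cassert>

// Local stand-in for llvm::divideCeil.
int divideCeil(int numerator, int denominator) {
  return (numerator + denominator - 1) / denominator;
}

int main() {
  int scale = 8 / 2;     // i2 in i8: four emulated elements per byte
  int origElements = 3;  // loading vector<3xi2>
  int intraOffset = 2;   // the load starts at the third i2 of a byte
  // The last requested i2 lives in the second byte, so two container
  // elements must be loaded even though 3 / 4 rounds down to zero.
  assert(divideCeil(intraOffset + origElements, scale) == 2);
}
```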
@@ -797,7 +804,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -814,7 +821,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
llvm::divideCeil(maxintraDataOffset + origElements, scale);
Value result =
emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
numElements, oldElementType, newElementType);
numElements, emulatedElemTy, containerElemTy);

if (!foldedIntraVectorOffset) {
auto resultVector = rewriter.create<arith::ConstantOp>(
@@ -848,17 +855,20 @@ struct ConvertVectorMaskedLoad final
"only 1-D vectors are supported ATM");

auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();

if (dstBits % srcBits != 0) {
auto containerElemTy =
cast<MemRefType>(adaptor.getBase().getType()).getElementType();
Type emulatedElemTy = op.getType().getElementType();
int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
int containerBits = containerElemTy.getIntOrFloatBitWidth();

// Check per-element alignment.
if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
op, "impossible to pack emulated elements into container elements "
"(bit-wise misalignment)");
}
int scale = dstBits / srcBits;
int scale = containerBits / emulatedBits;

// Adjust the number of elements to load when emulating narrow types,
// and then cast back to the original type with vector.bitcast op.
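The masked variant additionally needs the mask re-expressed at container granularity (getCompressedMaskOp in this file). Roughly, a container lane is active iff any emulated lane it packs is active; the sketch below models only that idea and ignores the representability checks the real helper performs:

```cpp
#include <cassert>
#include <vector>

// Rough model of mask compression: one container lane is active iff any of
// the 'scale' emulated lanes packed into it is active. Unlike the real
// getCompressedMaskOp, this sketch never rejects an unrepresentable mask.
std::vector<bool> compressMask(const std::vector<bool> &emulatedMask,
                               int scale) {
  std::vector<bool> containerMask(emulatedMask.size() / scale, false);
  for (size_t i = 0; i < emulatedMask.size(); ++i)
    if (emulatedMask[i])
      containerMask[i / scale] = true;
  return containerMask;
}

int main() {
  // vector<6xi4> mask [1,1,1,1,0,0] -> vector<3xi8> mask [1,1,0].
  auto m = compressMask({true, true, true, true, false, false}, 2);
  assert(m == std::vector<bool>({true, true, false}));
}
```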
@@ -912,7 +922,7 @@ struct ConvertVectorMaskedLoad final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
@@ -933,8 +943,8 @@ struct ConvertVectorMaskedLoad final

auto numElements =
llvm::divideCeil(maxIntraDataOffset + origElements, scale);
auto loadType = VectorType::get(numElements, newElementType);
auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
auto loadType = VectorType::get(numElements, containerElemTy);
auto newBitcastType = VectorType::get(numElements * scale, emulatedElemTy);

auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -1009,23 +1019,25 @@ struct ConvertVectorTransferRead final
"only 1-D vectors are supported ATM");

auto loc = op.getLoc();
auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
Type oldElementType = op.getType().getElementType();
Type newElementType = convertedType.getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = newElementType.getIntOrFloatBitWidth();

if (dstBits % srcBits != 0) {
auto containerElemTy =
cast<MemRefType>(adaptor.getSource().getType()).getElementType();
Type emulatedElemTy = op.getType().getElementType();
int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth();
int containerBits = containerElemTy.getIntOrFloatBitWidth();

// Check per-element alignment.
if (containerBits % emulatedBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
op, "impossible to pack emulated elements into container elements "
"(bit-wise misalignment)");
}
int scale = dstBits / srcBits;
int scale = containerBits / emulatedBits;

auto origElements = op.getVectorType().getNumElements();

bool isAlignedEmulation = origElements % scale == 0;

auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
adaptor.getPadding());

auto stridedMetadata =
@@ -1035,7 +1047,7 @@ struct ConvertVectorTransferRead final
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
rewriter, loc, emulatedBits, containerBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(),
Expand All @@ -1051,12 +1063,12 @@ struct ConvertVectorTransferRead final
llvm::divideCeil(maxIntraDataOffset + origElements, scale);

auto newRead = rewriter.create<vector::TransferReadOp>(
loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
newPadding);

auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements * scale, oldElementType), newRead);
loc, VectorType::get(numElements * scale, emulatedElemTy), newRead);

Value result = bitCast->getResult(0);
if (!foldedIntraVectorOffset) {
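One ConvertVectorTransferRead detail worth a concrete example: the padding value is zero-extended from the emulated type to the container type (the arith.extui above) before the widened transfer_read is built. A scalar sketch with assumed widths and low-nibble-first layout:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // arith.extui %pad : i4 to i8 -- zero-extension of the padding value.
  uint8_t pad4 = 0x5;              // i4 padding (fits in 4 bits)
  uint8_t pad8 = pad4;             // high nibble becomes zero
  assert(pad8 == 0x05);
  // Bitcast back to two i4 elements (assumed low-nibble-first layout):
  uint8_t lo = pad8 & 0xF, hi = pad8 >> 4;
  assert(lo == 0x5 && hi == 0x0);  // only the low lane carries the pad value
}
```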