
[mlir] Add narrow type emulation conversions #72181

Status: Closed · wants to merge 1 commit
272 changes: 232 additions & 40 deletions mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
@@ -35,36 +36,98 @@ using namespace mlir;
/// Return the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is
/// located at (x % 2) * 4, because there are two elements in one i8, and one
/// element has 4 bits.
/// element has 4 bits. If `rightOffset` is true, return the offset from the
/// right side of the `targetBits` container instead of the left side.
Review comment (Contributor):
This is confusing. I'd rather have two methods getLeftOffset... and getRightOffset... (also maybe it's worth finding something better than left and right)
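
A sketch of the suggested split (hypothetical wrapper names taken from the comment above; thin forwards to the helper defined below, shown for discussion only):

```cpp
// Hypothetical wrappers implementing the suggestion above: pin the flag of
// the existing getOffsetForBitwidth helper instead of passing a bool.
static Value getLeftOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                      int sourceBits, int targetBits,
                                      OpBuilder &builder) {
  return getOffsetForBitwidth(loc, srcIdx, sourceBits, targetBits, builder,
                              /*rightOffset=*/false);
}

static Value getRightOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                       int sourceBits, int targetBits,
                                       OpBuilder &builder) {
  return getOffsetForBitwidth(loc, srcIdx, sourceBits, targetBits, builder,
                              /*rightOffset=*/true);
}
```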

static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
int sourceBits, int targetBits,
OpBuilder &builder) {
OpBuilder &builder,
bool rightOffset = false) {
assert(targetBits % sourceBits == 0);
AffineExpr s0;
bindSymbols(builder.getContext(), s0);
int scaleFactor = targetBits / sourceBits;
OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
AffineExpr offsetExpr =
rightOffset ? (scaleFactor - 1 - s0 % scaleFactor) * sourceBits
: (s0 % scaleFactor) * sourceBits;
OpFoldResult offsetVal =
affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
IntegerType dstType = builder.getIntegerType(targetBits);
return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}
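
To see the two offset conventions side by side, here is a standalone arithmetic sketch (plain C++ with no MLIR dependencies; all names are illustrative) for i4 elements packed into i8:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t srcBits = 4, dstBits = 8;
  const int64_t scale = dstBits / srcBits; // two i4 elements per i8
  for (int64_t idx = 0; idx < 4; ++idx) {
    // Left offset, the default path: (idx % scale) * srcBits.
    const int64_t left = (idx % scale) * srcBits;
    // Right offset, rightOffset = true: (scale - 1 - idx % scale) * srcBits.
    const int64_t right = (scale - 1 - idx % scale) * srcBits;
    std::printf("idx=%lld left=%lld right=%lld\n", (long long)idx,
                (long long)left, (long long)right);
  }
  return 0;
}
```

For every index, left + right == dstBits - srcBits; that is, the right offset measures the distance of the element's slot from the opposite end of the byte.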

/// When storing a sub-byte value, the write must happen atomically in case
/// another write to the same byte happens at the same time. To do the write,
/// we first must clear the `dstBits` at the `linearizedIndices` of the
/// sub-byte store. This function returns the appropriate mask for clearing
/// these bits.
static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices,
int64_t srcBits, int64_t dstBits,
Value bitwidthOffset, OpBuilder &builder) {
auto dstIntegerType = builder.getIntegerType(dstBits);
auto maskRightAlignedAttr =
builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
Value maskRightAligned =
builder
.create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
.getResult();
Review comment (Contributor):
IIRC, .getResult() is not needed when the Value type is specified. Same for the other code below.
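
A minimal illustration of the reviewer's point, reusing the operands already in scope above (a single-result create call converts to Value implicitly):

```cpp
// No .getResult() needed: arith::ConstantOp has one result, so the created
// op converts to Value directly.
Value maskRightAligned = builder.create<arith::ConstantOp>(
    loc, dstIntegerType, maskRightAlignedAttr);
```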

Value writeMaskInverse =
builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
Value flipVal =
builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
.getResult();
return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
}
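
As a sanity check on the mask arithmetic, a standalone sketch (plain C++, not MLIR code; values chosen for illustration) for srcBits = 4, dstBits = 8 and a bitwidth offset of 4:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const unsigned srcBits = 4;
  const unsigned bitwidthOffset = 4; // the nibble being written: bits [4, 8)
  const uint8_t maskRightAligned = (1u << srcBits) - 1;                // 0x0F
  const uint8_t writeMaskInverse = maskRightAligned << bitwidthOffset; // 0xF0
  const uint8_t writeMask = writeMaskInverse ^ 0xFF; // XOR with -1 flips bits
  assert(writeMask == 0x0F); // and-ing keeps bits [0, 4) and clears [4, 8)
  return 0;
}
```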

/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
/// the returned index has the granularity of `dstBits`.
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
OpFoldResult linearizedIndex,
int64_t srcBits, int64_t dstBits) {
AffineExpr s0;
bindSymbols(builder.getContext(), s0);
int64_t scaler = dstBits / srcBits;
OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
builder, loc, s0.floorDiv(scaler), {linearizedIndex});
return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
}
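
For example, with srcBits = 4 and dstBits = 8 the scale factor is 2, so a linearized source index of 5 (counted in i4 elements) maps to index 5 floordiv 2 = 2 in the i8-backed memref.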

static OpFoldResult
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
const SmallVector<OpFoldResult> &indices,
Value memref) {
auto stridedMetadata =
builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
builder, loc, srcBits, srcBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(), indices);
return linearizedIndices;
}
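
For illustration (assuming an identity, row-major layout): for a memref<3x4xi4> with strides [4, 1] and zero offset, indices (i, j) linearize to i * 4 + j, still counted in i4 elements; rescaling that index to i8 granularity is done separately by getIndicesForLoadOrStore.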

namespace {

//===----------------------------------------------------------------------===//
// ConvertMemRefAlloc
//===----------------------------------------------------------------------===//

struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
template <typename OpTy>
struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> {
Review comment (Contributor):
Perhaps we can rename it to ConvertMemRefAllocation. It is used not only by memref.alloc but also by memref.alloca.

using OpConversionPattern<OpTy>::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto currentType = op.getMemref().getType().cast<MemRefType>();
auto newResultType =
getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
static_assert(std::is_same<OpTy, memref::AllocOp>() ||
std::is_same<OpTy, memref::AllocaOp>(),
"expected only memref::AllocOp or memref::AllocaOp");
auto currentType = cast<MemRefType>(op.getMemref().getType());
auto newResultType = dyn_cast<MemRefType>(
this->getTypeConverter()->convertType(op.getType()));
if (!newResultType) {
return rewriter.notifyMatchFailure(
op->getLoc(),
@@ -73,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {

// Special case zero-rank memrefs.
if (currentType.getRank() == 0) {
rewriter.replaceOpWithNewOp<memref::AllocOp>(
op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
adaptor.getAlignmentAttr());
rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
adaptor.getSymbolOperands(),
adaptor.getAlignmentAttr());
return success();
}

@@ -97,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
rewriter, loc, linearizedMemRefInfo.linearizedSize));
}

rewriter.replaceOpWithNewOp<memref::AllocOp>(
op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
adaptor.getAlignmentAttr());
rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
adaptor.getSymbolOperands(),
adaptor.getAlignmentAttr());
return success();
}
};
@@ -155,32 +218,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
ValueRange{});
} else {
SmallVector<OpFoldResult> indices =
getAsOpFoldResult(adaptor.getIndices());

auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
loc, op.getMemRef());

// Linearize the indices of the original load instruction. Do not account
// for the scaling yet. This will be accounted for later.
OpFoldResult linearizedIndices;
std::tie(std::ignore, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, srcBits,
stridedMetadata.getConstifiedMixedOffset(),
stridedMetadata.getConstifiedMixedSizes(),
stridedMetadata.getConstifiedMixedStrides(), indices);

AffineExpr s0;
bindSymbols(rewriter.getContext(), s0);
int64_t scaler = dstBits / srcBits;
OpFoldResult scaledLinearizedIndices =
affine::makeComposedFoldedAffineApply(
rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
OpFoldResult linearizedIndices = getLinearizedSrcIndices(
rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());

Value newLoad = rewriter.create<memref::LoadOp>(
loc, adaptor.getMemref(),
getValueOrCreateConstantIndexOp(rewriter, loc,
scaledLinearizedIndices));
getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
dstBits));

// Get the offset and shift the bits to the rightmost.
// Note, currently only the big-endian is supported.
@@ -211,6 +257,150 @@
}
};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Currently there is very limited support for memref::ReinterpretCastOp
/// conversion. Only the 0-dimensional case is supported.
Review comment (Contributor):
Why is only 0-D supported? Are we able to emulate the other cases?

struct ConvertMemRefReinterpretCast final
: OpConversionPattern<memref::ReinterpretCastOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType newTy =
dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
if (!newTy) {
return rewriter.notifyMatchFailure(
op->getLoc(),
llvm::formatv("failed to convert memref type: {0}", op.getType()));
}

auto convertedElementType = newTy.getElementType();
auto oldElementType = op.getType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = convertedElementType.getIntOrFloatBitWidth();
if (dstBits % srcBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
}

// Only support offset for 0-D reinterpret_cast.
if (op.getType().getRank() != 0) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with rank > 0 is not supported");
}

int64_t offset = op.getStaticOffset(0);
// Only support static sizes and offsets.
if (offset == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with dynamic offset is not supported");
}

int elementsPerByte = dstBits / srcBits;
if (offset % elementsPerByte != 0) {
return rewriter.notifyMatchFailure(
op->getLoc(),
"subview with offset not multiple of elementsPerByte is not "
"supported");
}

offset = offset / elementsPerByte;

rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
op, newTy, *adaptor.getODSOperands(0).begin(), offset,
SmallVector<int64_t>{}, op.getStaticStrides());
return success();
}
};
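
As a concrete trace of the offset rescaling (illustrative numbers): a 0-D reinterpret_cast with static offset 6 over i4 data emulated on i8 gives elementsPerByte = 2 and a converted offset of 6 / 2 = 3, while an odd offset such as 5 is rejected by the multiple-of-elementsPerByte check above.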

//===----------------------------------------------------------------------===//
// ConvertMemrefStore
//===----------------------------------------------------------------------===//

struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
auto convertedElementType = convertedType.getElementType();
auto oldElementType = op.getMemRefType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = convertedElementType.getIntOrFloatBitWidth();
auto dstIntegerType = rewriter.getIntegerType(dstBits);
if (dstBits % srcBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
}

Location loc = op.getLoc();
Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
adaptor.getValue());

// Special case 0-rank memref stores. We can compute the mask at compile
// time.
if (convertedType.getRank() == 0) {
Review comment (Contributor) on lines +345 to +347:
The code is becoming larger than I expected. Let's create two static functions, one for the 0-D case and the other for the non-0-D case. What do you think?
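
One possible shape for that factoring (a hypothetical helper, shown for discussion; it only reuses ops this pattern already emits): both the 0-D path and the indexed path end in the same clear-then-set pair of atomics, which could be shared:

```cpp
// Hypothetical shared tail for both store paths: clear the destination bits
// with an atomic andi, then write the aligned source bits with an atomic ori.
static void atomicStoreBits(OpBuilder &builder, Location loc, Value memref,
                            ValueRange storeIndices, Value writeMask,
                            Value alignedVal) {
  builder.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
                                      writeMask, memref, storeIndices);
  builder.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
                                      alignedVal, memref, storeIndices);
}
```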

// Shift extended value to be left aligned
auto shiftValAttr =
rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
Value shiftVal =
rewriter.create<arith::ConstantOp>(loc, dstIntegerType, shiftValAttr)
.getResult();
Value alignedVal =
rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal)
.getResult();
// Create mask to clear destination bits
auto writeMaskValAttr = rewriter.getIntegerAttr(
dstIntegerType, (1 << (dstBits - srcBits)) - 1);
Value writeMask =
rewriter
.create<arith::ConstantOp>(loc, dstIntegerType, writeMaskValAttr)
.getResult();

// Clear destination bits
rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
writeMask, adaptor.getMemref(),
ValueRange{});
// Write src bits to destination
rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
alignedVal, adaptor.getMemref(),
ValueRange{});
rewriter.eraseOp(op);
Review comment (Contributor):
Can we use replaceOp instead? That's more common in pattern rewrites.

return success();
}

OpFoldResult linearizedIndices = getLinearizedSrcIndices(
rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
Value storeIndices = getIndicesForLoadOrStore(
rewriter, loc, linearizedIndices, srcBits, dstBits);
Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
dstBits, rewriter, true);
Value writeMask = getAtomicWriteMask(loc, linearizedIndices, srcBits,
dstBits, bitwidthOffset, rewriter);
// Align the value to write with the destination bits
Value alignedVal =
rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset)
.getResult();

// Clear destination bits
rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
writeMask, adaptor.getMemref(),
storeIndices);
// Write src bits to destination
rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
alignedVal, adaptor.getMemref(),
storeIndices);

rewriter.eraseOp(op);
return success();
}
};
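
To check the end-to-end arithmetic of this pattern, a standalone simulation (plain C++; names are illustrative, and it mirrors the right-aligned offset convention used above rather than any canonical layout):

```cpp
#include <cstdint>
#include <cstdio>

// Simulates the emulated i4 store at linear element index `idx` into an
// i8-backed buffer, following the steps of ConvertMemrefStore above.
void emulatedStoreI4(uint8_t *buf, int64_t idx, uint8_t val) {
  const int64_t srcBits = 4, dstBits = 8;
  const int64_t scale = dstBits / srcBits;
  const int64_t byteIdx = idx / scale; // getIndicesForLoadOrStore
  const int64_t offset = (scale - 1 - idx % scale) * srcBits; // right-aligned
  const uint8_t writeMask =
      static_cast<uint8_t>(~(((1u << srcBits) - 1) << offset));
  buf[byteIdx] &= writeMask;                           // atomic andi: clear
  buf[byteIdx] |= static_cast<uint8_t>(val << offset); // atomic ori: set
}

int main() {
  uint8_t buf[2] = {0xFF, 0xFF};
  emulatedStoreI4(buf, 3, 0xA); // element 3 lands in the low nibble of byte 1
  std::printf("%02X %02X\n", buf[0], buf[1]); // prints: FF FA
  return 0;
}
```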

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
@@ -291,8 +481,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {

// Populate `memref.*` conversion patterns.
patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
patterns.add<ConvertMemRefAlloc<memref::AllocOp>,
ConvertMemRefAlloc<memref::AllocaOp>, ConvertMemRefLoad,
ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
ConvertMemrefStore, ConvertMemRefReinterpretCast>(
typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
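
For context, a sketch of how a driver pass can wire these patterns up (modeled on the in-tree test pass; assumes this runs inside a Pass with `ctx` and `op` available, and the legality setup is simplified):

```cpp
// Sketch: emulate narrow (e.g. i4) memref ops on an i8-backed representation.
RewritePatternSet patterns(ctx);
arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);

ConversionTarget target(*ctx);
target.markUnknownOpDynamicallyLegal(
    [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
if (failed(applyPartialConversion(op, target, std::move(patterns))))
  signalPassFailure();
```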