[mlir] Add narrow type emulation conversions #72181
@@ -17,6 +17,7 @@ | |
#include "mlir/Dialect/MemRef/Transforms/Transforms.h" | ||
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/IR/OpDefinition.h" | ||
#include "mlir/Support/MathExtras.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "llvm/Support/FormatVariadic.h" | ||
|
@@ -35,36 +36,98 @@ using namespace mlir; | |
/// Return the bit offset of the value at position `srcIdx`. For example, if | ||
/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is | ||
/// located at (x % 2) * 4. Because there are two elements in one i8, and one | ||
/// element has 4 bits. | ||
/// element has 4 bits. If `rightOffset` is true, return the offset from the | ||
/// right side of the `dstBits` container instead of the left side. | ||
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx, | ||
int sourceBits, int targetBits, | ||
OpBuilder &builder) { | ||
OpBuilder &builder, | ||
bool rightOffset = false) { | ||
assert(targetBits % sourceBits == 0); | ||
AffineExpr s0; | ||
bindSymbols(builder.getContext(), s0); | ||
int scaleFactor = targetBits / sourceBits; | ||
OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply( | ||
builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx}); | ||
AffineExpr offsetExpr = | ||
rightOffset ? (scaleFactor - 1 - s0 % scaleFactor) * sourceBits | ||
: (s0 % scaleFactor) * sourceBits; | ||
OpFoldResult offsetVal = | ||
affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx}); | ||
Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal); | ||
IntegerType dstType = builder.getIntegerType(targetBits); | ||
return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset); | ||
} | ||
|
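For reference, the offset arithmetic above reduces to ordinary integer math. Below is a minimal standalone C++ sketch of the same expressions (the function name is illustrative and not part of the patch):

```cpp
#include <cassert>
#include <cstdio>

// Bit offset of the element at `srcIdx` inside its dstBits-wide container.
// Mirrors (s0 % scaleFactor) * srcBits for the left offset and
// (scaleFactor - 1 - s0 % scaleFactor) * srcBits for the right offset.
int offsetForBitwidth(int srcIdx, int srcBits, int dstBits, bool rightOffset) {
  assert(dstBits % srcBits == 0);
  int scaleFactor = dstBits / srcBits;
  int pos = srcIdx % scaleFactor;
  return (rightOffset ? scaleFactor - 1 - pos : pos) * srcBits;
}

int main() {
  // i4 packed into i8: element 0 -> left 0 / right 4; element 1 -> left 4 / right 0.
  printf("%d %d\n", offsetForBitwidth(0, 4, 8, false), offsetForBitwidth(0, 4, 8, true));
  printf("%d %d\n", offsetForBitwidth(1, 4, 8, false), offsetForBitwidth(1, 4, 8, true));
}
```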
||
/// When writing a sub-byte value, the write must happen atomically in case | ||
/// another write to the same byte occurs at the same time. To do the write, | ||
/// we first must clear `dstBits` at the `linearizedIndices` of the subbyte | ||
/// store. This function returns the appropriate mask for clearing these bits. | ||
static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices, | ||
int64_t srcBits, int64_t dstBits, | ||
Value bitwidthOffset, OpBuilder &builder) { | ||
auto dstIntegerType = builder.getIntegerType(dstBits); | ||
auto maskRightAlignedAttr = | ||
builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1); | ||
Value maskRightAligned = | ||
builder | ||
.create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr) | ||
.getResult(); | ||
Reviewer comment: IIRC, …
||
Value writeMaskInverse = | ||
builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset); | ||
auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1); | ||
Value flipVal = | ||
builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr) | ||
.getResult(); | ||
return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal); | ||
} | ||
|
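The mask built here can be checked with plain integer arithmetic: a right-aligned run of `srcBits` ones is shifted to the write position and then inverted by XOR with all-ones, so an AND clears exactly the destination bits. A small sketch of that computation (the helper name and the i8 container width are illustrative assumptions):

```cpp
#include <cstdint>
#include <cstdio>

// Ones everywhere except the srcBits-wide field at bitOffset, so that
// `byte & mask` zeroes exactly the field about to be written.
uint8_t atomicWriteMask(int srcBits, int bitOffset) {
  uint8_t maskRightAligned = (1u << srcBits) - 1;  // e.g. 0b00001111 for i4
  uint8_t writeMaskInverse = maskRightAligned << bitOffset;
  return writeMaskInverse ^ 0xFFu;                 // flip all bits of the i8
}

int main() {
  printf("0x%02X\n", atomicWriteMask(4, 4));  // 0x0F: clears the high nibble
  printf("0x%02X\n", atomicWriteMask(4, 0));  // 0xF0: clears the low nibble
}
```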
||
/// Returns the scaled linearized index based on the `srcBits` and `dstBits` | ||
/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and | ||
/// the returned index has the granularity of `dstBits`. | ||
static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc, | ||
OpFoldResult linearizedIndex, | ||
int64_t srcBits, int64_t dstBits) { | ||
AffineExpr s0; | ||
bindSymbols(builder.getContext(), s0); | ||
int64_t scaler = dstBits / srcBits; | ||
OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply( | ||
builder, loc, s0.floorDiv(scaler), {linearizedIndex}); | ||
return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices); | ||
} | ||
|
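With i4 elements in i8 containers the scaling is simply a floor division by two, for example:

```cpp
#include <cstdio>

int main() {
  int srcBits = 4, dstBits = 8;
  int scaler = dstBits / srcBits;  // two i4 elements per i8 container
  int indices[] = {0, 1, 2, 3, 4, 5};
  for (int idx : indices)
    printf("%d -> %d\n", idx, idx / scaler);  // 0->0, 1->0, 2->1, 3->1, 4->2, 5->2
}
```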
||
static OpFoldResult | ||
getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits, | ||
const SmallVector<OpFoldResult> &indices, | ||
Value memref) { | ||
auto stridedMetadata = | ||
builder.create<memref::ExtractStridedMetadataOp>(loc, memref); | ||
OpFoldResult linearizedIndices; | ||
std::tie(std::ignore, linearizedIndices) = | ||
memref::getLinearizedMemRefOffsetAndSize( | ||
builder, loc, srcBits, srcBits, | ||
stridedMetadata.getConstifiedMixedOffset(), | ||
stridedMetadata.getConstifiedMixedSizes(), | ||
stridedMetadata.getConstifiedMixedStrides(), indices); | ||
return linearizedIndices; | ||
} | ||
|
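`getLinearizedMemRefOffsetAndSize` is the existing MemRefUtils helper; for a statically strided memref, the index part of what it computes is the familiar offset-plus-stride dot product, spelled out in the sketch below (a simplification of the utility, assuming static strides and ignoring the size computation):

```cpp
#include <cstdio>

// Linearized index = offset + sum(indices[d] * strides[d]), at srcBits
// granularity (no dstBits scaling applied yet).
long linearizeIndex(const long *indices, const long *strides, int rank, long offset) {
  long linear = offset;
  for (int d = 0; d < rank; ++d)
    linear += indices[d] * strides[d];
  return linear;
}

int main() {
  // memref<3x4xi4> with row-major strides [4, 1]: element (2, 3) -> 2*4 + 3 = 11.
  long indices[] = {2, 3}, strides[] = {4, 1};
  printf("%ld\n", linearizeIndex(indices, strides, 2, /*offset=*/0));
}
```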
||
namespace { | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ConvertMemRefAlloc | ||
//===----------------------------------------------------------------------===// | ||
|
||
struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> { | ||
using OpConversionPattern::OpConversionPattern; | ||
template <typename OpTy> | ||
struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> { | ||
Reviewer comment: Perhaps we can rename it to …
||
using OpConversionPattern<OpTy>::OpConversionPattern; | ||
|
||
LogicalResult | ||
matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor, | ||
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
auto currentType = op.getMemref().getType().cast<MemRefType>(); | ||
auto newResultType = | ||
getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>(); | ||
static_assert(std::is_same<OpTy, memref::AllocOp>() || | ||
std::is_same<OpTy, memref::AllocaOp>(), | ||
"expected only memref::AllocOp or memref::AllocaOp"); | ||
auto currentType = cast<MemRefType>(op.getMemref().getType()); | ||
auto newResultType = dyn_cast<MemRefType>( | ||
this->getTypeConverter()->convertType(op.getType())); | ||
if (!newResultType) { | ||
return rewriter.notifyMatchFailure( | ||
op->getLoc(), | ||
|
@@ -73,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> { | |
|
||
// Special case zero-rank memrefs. | ||
if (currentType.getRank() == 0) { | ||
rewriter.replaceOpWithNewOp<memref::AllocOp>( | ||
op, newResultType, ValueRange{}, adaptor.getSymbolOperands(), | ||
adaptor.getAlignmentAttr()); | ||
rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{}, | ||
adaptor.getSymbolOperands(), | ||
adaptor.getAlignmentAttr()); | ||
return success(); | ||
} | ||
|
||
|
@@ -97,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> { | |
rewriter, loc, linearizedMemRefInfo.linearizedSize)); | ||
} | ||
|
||
rewriter.replaceOpWithNewOp<memref::AllocOp>( | ||
op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(), | ||
adaptor.getAlignmentAttr()); | ||
rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize, | ||
adaptor.getSymbolOperands(), | ||
adaptor.getAlignmentAttr()); | ||
return success(); | ||
} | ||
}; | ||
|
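For the dynamic-size path above, the linearized allocation size at dstBits granularity amounts to a ceiling division of the total number of source bits by the container width. A back-of-the-envelope sketch of that arithmetic (this packing assumption is stated outside the MLIR APIs; the helper name is illustrative):

```cpp
#include <cstdio>

// Number of dstBits-wide containers needed to hold numElements values of
// srcBits each, packed densely (ceiling division).
long linearizedAllocSize(long numElements, int srcBits, int dstBits) {
  return (numElements * srcBits + dstBits - 1) / dstBits;
}

int main() {
  // memref<7xi4> needs ceil(7 * 4 / 8) = 4 i8 containers.
  printf("%ld\n", linearizedAllocSize(7, 4, 8));  // 4
}
```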
@@ -155,32 +218,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> { | |
bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(), | ||
ValueRange{}); | ||
} else { | ||
SmallVector<OpFoldResult> indices = | ||
getAsOpFoldResult(adaptor.getIndices()); | ||
|
||
auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>( | ||
loc, op.getMemRef()); | ||
|
||
// Linearize the indices of the original load instruction. Do not account | ||
// for the scaling yet. This will be accounted for later. | ||
OpFoldResult linearizedIndices; | ||
std::tie(std::ignore, linearizedIndices) = | ||
memref::getLinearizedMemRefOffsetAndSize( | ||
rewriter, loc, srcBits, srcBits, | ||
stridedMetadata.getConstifiedMixedOffset(), | ||
stridedMetadata.getConstifiedMixedSizes(), | ||
stridedMetadata.getConstifiedMixedStrides(), indices); | ||
|
||
AffineExpr s0; | ||
bindSymbols(rewriter.getContext(), s0); | ||
int64_t scaler = dstBits / srcBits; | ||
OpFoldResult scaledLinearizedIndices = | ||
affine::makeComposedFoldedAffineApply( | ||
rewriter, loc, s0.floorDiv(scaler), {linearizedIndices}); | ||
OpFoldResult linearizedIndices = getLinearizedSrcIndices( | ||
rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef()); | ||
|
||
Value newLoad = rewriter.create<memref::LoadOp>( | ||
loc, adaptor.getMemref(), | ||
getValueOrCreateConstantIndexOp(rewriter, loc, | ||
scaledLinearizedIndices)); | ||
getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits, | ||
dstBits)); | ||
|
||
// Get the offset and shift the bits to the rightmost. | ||
// Note, currently only the big-endian is supported. | ||
|
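Putting the load path together: the emulated load fetches the containing i8 at the scaled index, shifts the selected field down to the least-significant bits, and truncates to srcBits. A plain C++ sketch of that sequence for i4-in-i8 (the nibble packing order shown is just one convention; names are illustrative):

```cpp
#include <cstdint>
#include <cstdio>

// Emulated i4 load: fetch the containing byte, shift the field to the
// rightmost bits, and mask down to 4 bits.
uint8_t emulatedLoadI4(const uint8_t *buf, int srcIdx) {
  uint8_t byte = buf[srcIdx / 2];     // index at dstBits granularity
  int bitOffset = (srcIdx % 2) * 4;   // same formula as getOffsetForBitwidth
  return (byte >> bitOffset) & 0x0F;  // shift right, then truncate
}

int main() {
  uint8_t buf[] = {0x21, 0x43};  // elements 1, 2, 3, 4 under this packing
  for (int i = 0; i < 4; ++i)
    printf("element %d = %u\n", i, emulatedLoadI4(buf, i));
}
```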
@@ -211,6 +257,150 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> { | |
} | ||
}; | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ConvertMemRefReinterpretCast | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// Currently there is very limited support for memref::ReinterpretCastOp | ||
/// conversion. Only the 0-dimensional case is supported. | ||
Reviewer comment: Why is only the 0-D case supported? Are we able to emulate other cases?
||
struct ConvertMemRefReinterpretCast final | ||
: OpConversionPattern<memref::ReinterpretCastOp> { | ||
using OpConversionPattern::OpConversionPattern; | ||
|
||
LogicalResult | ||
matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
MemRefType newTy = | ||
dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType())); | ||
if (!newTy) { | ||
return rewriter.notifyMatchFailure( | ||
op->getLoc(), | ||
llvm::formatv("failed to convert memref type: {0}", op.getType())); | ||
} | ||
|
||
auto convertedElementType = newTy.getElementType(); | ||
auto oldElementType = op.getType().getElementType(); | ||
int srcBits = oldElementType.getIntOrFloatBitWidth(); | ||
int dstBits = convertedElementType.getIntOrFloatBitWidth(); | ||
if (dstBits % srcBits != 0) { | ||
return rewriter.notifyMatchFailure( | ||
op, "only dstBits % srcBits == 0 supported"); | ||
} | ||
|
||
// Only support offset for 0-D subview. | ||
if (op.getType().getRank() != 0) { | ||
return rewriter.notifyMatchFailure( | ||
op->getLoc(), "subview with rank > 0 is not supported"); | ||
} | ||
|
||
int64_t offset = op.getStaticOffset(0); | ||
// Only support static sizes and offsets. | ||
if (offset == ShapedType::kDynamic) { | ||
return rewriter.notifyMatchFailure( | ||
op->getLoc(), "subview with dynamic offset is not supported"); | ||
} | ||
|
||
int elementsPerByte = dstBits / srcBits; | ||
if (offset % elementsPerByte != 0) { | ||
return rewriter.notifyMatchFailure( | ||
op->getLoc(), | ||
"subview with offset not multiple of elementsPerByte is not " | ||
"supported"); | ||
} | ||
|
||
offset = offset / elementsPerByte; | ||
|
||
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>( | ||
op, newTy, *adaptor.getODSOperands(0).begin(), offset, | ||
SmallVector<int64_t>{}, op.getStaticStrides()); | ||
return success(); | ||
} | ||
}; | ||
|
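The offset rewrite above is a division by the number of sub-byte elements packed per container, with a bail-out when the offset does not fall on a container boundary. Worked out for i4-in-i8 (elementsPerByte = 2): an offset of 6 i4 elements becomes a byte offset of 3, while an offset of 5 is rejected. A small sketch (the helper and the -1 error convention are illustrative):

```cpp
#include <cstdio>

// Scale a sub-byte element offset to dstBits granularity, or return -1 when
// it is not a multiple of elementsPerByte (mirroring the match-failure path).
long scaleReinterpretOffset(long offset, int srcBits, int dstBits) {
  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0)
    return -1;
  return offset / elementsPerByte;
}

int main() {
  printf("%ld\n", scaleReinterpretOffset(6, 4, 8));  // 3
  printf("%ld\n", scaleReinterpretOffset(5, 4, 8));  // -1 (unsupported)
}
```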
||
//===----------------------------------------------------------------------===// | ||
// ConvertMemrefStore | ||
//===----------------------------------------------------------------------===// | ||
|
||
struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> { | ||
using OpConversionPattern::OpConversionPattern; | ||
|
||
LogicalResult | ||
matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
auto convertedType = adaptor.getMemref().getType().cast<MemRefType>(); | ||
auto convertedElementType = convertedType.getElementType(); | ||
auto oldElementType = op.getMemRefType().getElementType(); | ||
int srcBits = oldElementType.getIntOrFloatBitWidth(); | ||
int dstBits = convertedElementType.getIntOrFloatBitWidth(); | ||
auto dstIntegerType = rewriter.getIntegerType(dstBits); | ||
if (dstBits % srcBits != 0) { | ||
return rewriter.notifyMatchFailure( | ||
op, "only dstBits % srcBits == 0 supported"); | ||
} | ||
|
||
Location loc = op.getLoc(); | ||
Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, | ||
adaptor.getValue()); | ||
|
||
// Special case 0-rank memref stores. We can compute the mask at compile | ||
// time. | ||
if (convertedType.getRank() == 0) { | ||
Reviewer comment (on lines +345 to +347): The code is becoming larger than I expected. Let's create two static functions: one for the 0-D case and the other for the non-0-D case. What do you think?
||
// Shift extended value to be left aligned | ||
auto shiftValAttr = | ||
rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits); | ||
Value shiftVal = | ||
rewriter.create<arith::ConstantOp>(loc, dstIntegerType, shiftValAttr) | ||
.getResult(); | ||
Value alignedVal = | ||
rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal) | ||
.getResult(); | ||
// Create mask to clear destination bits | ||
auto writeMaskValAttr = rewriter.getIntegerAttr( | ||
dstIntegerType, (1 << (dstBits - srcBits)) - 1); | ||
Value writeMask = | ||
rewriter | ||
.create<arith::ConstantOp>(loc, dstIntegerType, writeMaskValAttr) | ||
.getResult(); | ||
|
||
// Clear destination bits | ||
rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi, | ||
writeMask, adaptor.getMemref(), | ||
ValueRange{}); | ||
// Write src bits to destination | ||
rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori, | ||
alignedVal, adaptor.getMemref(), | ||
ValueRange{}); | ||
rewriter.eraseOp(op); | ||
Reviewer comment: Can we use …
||
return success(); | ||
} | ||
|
||
OpFoldResult linearizedIndices = getLinearizedSrcIndices( | ||
rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef()); | ||
Value storeIndices = getIndicesForLoadOrStore( | ||
rewriter, loc, linearizedIndices, srcBits, dstBits); | ||
Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits, | ||
dstBits, rewriter, true); | ||
Value writeMask = getAtomicWriteMask(loc, linearizedIndices, srcBits, | ||
dstBits, bitwidthOffset, rewriter); | ||
// Align the value to write with the destination bits | ||
Value alignedVal = | ||
rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset) | ||
.getResult(); | ||
|
||
// Clear destination bits | ||
rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi, | ||
writeMask, adaptor.getMemref(), | ||
storeIndices); | ||
// Write src bits to destination | ||
rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori, | ||
alignedVal, adaptor.getMemref(), | ||
storeIndices); | ||
|
||
rewriter.eraseOp(op); | ||
return success(); | ||
} | ||
}; | ||
|
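The store pattern above implements the sub-byte store as a two-step atomic read-modify-write: an atomic AND with the write mask clears the destination bits, then an atomic OR merges in the shifted value, so a concurrent store to the neighboring field in the same byte cannot be lost. A minimal std::atomic sketch of the same scheme for i4-in-i8 (names are illustrative):

```cpp
#include <atomic>
#include <cstdint>
#include <cstdio>

// Store a 4-bit value at bit position `bitOffset` of `byte` without
// disturbing the other nibble, even under concurrent emulated stores.
void emulatedStoreI4(std::atomic<uint8_t> &byte, int bitOffset, uint8_t value) {
  uint8_t aligned = (value & 0x0F) << bitOffset;  // shift value into place
  uint8_t clearMask = ~(0x0F << bitOffset);       // ones outside the field
  byte.fetch_and(clearMask);                      // like the andi atomic_rmw
  byte.fetch_or(aligned);                         // like the ori atomic_rmw
}

int main() {
  std::atomic<uint8_t> byte{0xFF};
  emulatedStoreI4(byte, /*bitOffset=*/4, 0x3);
  printf("0x%02X\n", byte.load());  // 0x3F: high nibble replaced, low nibble kept
}
```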
||
//===----------------------------------------------------------------------===// | ||
// ConvertMemRefSubview | ||
//===----------------------------------------------------------------------===// | ||
|
@@ -291,8 +481,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns( | |
RewritePatternSet &patterns) { | ||
|
||
// Populate `memref.*` conversion patterns. | ||
patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad, | ||
ConvertMemRefAssumeAlignment, ConvertMemRefSubview>( | ||
patterns.add<ConvertMemRefAlloc<memref::AllocOp>, | ||
ConvertMemRefAlloc<memref::AllocaOp>, ConvertMemRefLoad, | ||
ConvertMemRefAssumeAlignment, ConvertMemRefSubview, | ||
ConvertMemrefStore, ConvertMemRefReinterpretCast>( | ||
typeConverter, patterns.getContext()); | ||
memref::populateResolveExtractStridedMetadataPatterns(patterns); | ||
} | ||
|
Reviewer comment: This is confusing. I'd rather have two methods, getLeftOffset... and getRightOffset... (also, maybe it's worth finding something better than "left" and "right").