[mlir] Add narrow type emulation conversions #72181
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref

Author: None (Max191)

Changes: Adds narrow type emulation support for `memref.alloca`, `memref.store`, and `memref.reinterpret_cast`.

Patch is 30.02 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/72181.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..078df55e351db96 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
@@ -35,36 +36,98 @@ using namespace mlir;
/// Return the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
/// located at (x % 2) * 4. Because there are two elements in one i8, and one
-/// element has 4 bits.
+/// element has 4 bits. If `rightOffset` is true, return the offset from the
+/// right side of the `dstBits` container instead of the left side.
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
int sourceBits, int targetBits,
- OpBuilder &builder) {
+ OpBuilder &builder,
+ bool rightOffset = false) {
assert(targetBits % sourceBits == 0);
AffineExpr s0;
bindSymbols(builder.getContext(), s0);
int scaleFactor = targetBits / sourceBits;
- OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
- builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+ AffineExpr offsetExpr =
+ rightOffset ? (scaleFactor - 1 - s0 % scaleFactor) * sourceBits
+ : (s0 % scaleFactor) * sourceBits;
+ OpFoldResult offsetVal =
+ affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
IntegerType dstType = builder.getIntegerType(targetBits);
return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}
+/// When writing a subbyte-sized value, the write needs to happen atomically
+/// in case another write to the same byte occurs at the same time. To do the
+/// write, we first must clear `dstBits` at the `linearizedIndices` of the
+/// subbyte store. This function returns the appropriate mask for clearing
+/// these bits.
+static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices,
+ int64_t srcBits, int64_t dstBits,
+ Value bitwidthOffset, OpBuilder &builder) {
+ auto dstIntegerType = builder.getIntegerType(dstBits);
+ auto maskRightAlignedAttr =
+ builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+ Value maskRightAligned =
+ builder
+ .create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
+ .getResult();
+ Value writeMaskInverse =
+ builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+ auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+ Value flipVal =
+ builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+ .getResult();
+ return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`.
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+ OpFoldResult linearizedIndex,
+ int64_t srcBits, int64_t dstBits) {
+ AffineExpr s0;
+ bindSymbols(builder.getContext(), s0);
+ int64_t scaler = dstBits / srcBits;
+ OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+ builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+ return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+ const SmallVector<OpFoldResult> &indices,
+ Value memref) {
+ auto stridedMetadata =
+ builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+ OpFoldResult linearizedIndices;
+ std::tie(std::ignore, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ builder, loc, srcBits, srcBits,
+ stridedMetadata.getConstifiedMixedOffset(),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides(), indices);
+ return linearizedIndices;
+}
+
namespace {
//===----------------------------------------------------------------------===//
// ConvertMemRefAlloc
//===----------------------------------------------------------------------===//
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
- using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> {
+ using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult
- matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+ matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto currentType = op.getMemref().getType().cast<MemRefType>();
- auto newResultType =
- getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+ static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+ std::is_same<OpTy, memref::AllocaOp>(),
+ "expected only memref::AllocOp or memref::AllocaOp");
+ auto currentType = cast<MemRefType>(op.getMemref().getType());
+ auto newResultType = dyn_cast<MemRefType>(
+ this->getTypeConverter()->convertType(op.getType()));
if (!newResultType) {
return rewriter.notifyMatchFailure(
op->getLoc(),
@@ -73,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
// Special case zero-rank memrefs.
if (currentType.getRank() == 0) {
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
- adaptor.getAlignmentAttr());
+ rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+ adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
return success();
}
@@ -97,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
rewriter, loc, linearizedMemRefInfo.linearizedSize));
}
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
- adaptor.getAlignmentAttr());
+ rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+ adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
return success();
}
};
@@ -155,32 +218,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
ValueRange{});
} else {
- SmallVector<OpFoldResult> indices =
- getAsOpFoldResult(adaptor.getIndices());
-
- auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, op.getMemRef());
-
// Linearize the indices of the original load instruction. Do not account
// for the scaling yet. This will be accounted for later.
- OpFoldResult linearizedIndices;
- std::tie(std::ignore, linearizedIndices) =
- memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, srcBits,
- stridedMetadata.getConstifiedMixedOffset(),
- stridedMetadata.getConstifiedMixedSizes(),
- stridedMetadata.getConstifiedMixedStrides(), indices);
-
- AffineExpr s0;
- bindSymbols(rewriter.getContext(), s0);
- int64_t scaler = dstBits / srcBits;
- OpFoldResult scaledLinearizedIndices =
- affine::makeComposedFoldedAffineApply(
- rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+ OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+ rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
Value newLoad = rewriter.create<memref::LoadOp>(
loc, adaptor.getMemref(),
- getValueOrCreateConstantIndexOp(rewriter, loc,
- scaledLinearizedIndices));
+ getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+ dstBits));
// Get the offset and shift the bits to the rightmost.
// Note, currently only the big-endian is supported.
@@ -211,6 +257,136 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+/// Currently there is very limited support for memref::ReinterpretCastOp
+/// conversion. Only the 0 dimensional case is supported.
+struct ConvertMemRefReinterpretCast final
+ : OpConversionPattern<memref::ReinterpretCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ MemRefType newTy =
+ dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+ if (!newTy) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ llvm::formatv("failed to convert memref type: {0}", op.getType()));
+ }
+
+ auto convertedElementType = newTy.getElementType();
+ auto oldElementType = op.getType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+ // Only support offset for 0-D subview.
+ if (op.getType().getRank() != 0) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "subview with rank > 0 is not supported");
+ }
+
+ int64_t offset = op.getStaticOffset(0);
+ // Only support static sizes and offsets.
+ if (offset == ShapedType::kDynamic) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "subview with dynamic offset is not supported");
+ }
+
+ int elementsPerByte = dstBits / srcBits;
+ if (offset % elementsPerByte != 0) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ "subview with offset not multiple of elementsPerByte is not "
+ "supported");
+ }
+
+ offset = offset / elementsPerByte;
+
+ rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+ op, newTy, *adaptor.getODSOperands(0).begin(), offset,
+ SmallVector<int64_t>{}, op.getStaticStrides());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+ auto convertedElementType = convertedType.getElementType();
+ auto oldElementType = op.getMemRefType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ auto dstIntegerType = rewriter.getIntegerType(dstBits);
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+ Location loc = op.getLoc();
+ Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, adaptor.getValue());
+
+ // Special case 0-rank memref stores. We can compute the mask at compile
+ // time.
+ if (convertedType.getRank() == 0) {
+ // Shift extended value to be left aligned
+ auto shiftValAttr = rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
+ Value shiftVal = rewriter.create<arith::ConstantOp>(loc, dstIntegerType, shiftValAttr).getResult();
+ Value alignedVal = rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal).getResult();
+ // Create mask to clear destination bits
+ auto writeMaskValAttr = rewriter.getIntegerAttr(
+ dstIntegerType, (1 << (dstBits - srcBits)) - 1);
+ Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType, writeMaskValAttr).getResult();
+
+ // Clear destination bits
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi, writeMask, adaptor.getMemref(), ValueRange{});
+ // Write src bits to destination
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori, alignedVal, adaptor.getMemref(), ValueRange{});
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+ OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+ rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+ Value storeIndices = getIndicesForLoadOrStore(
+ rewriter, loc, linearizedIndices, srcBits, dstBits);
+ Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+ dstBits, rewriter, true);
+ Value writeMask = getAtomicWriteMask(loc, linearizedIndices, srcBits,
+ dstBits, bitwidthOffset, rewriter);
+ // Align the value to write with the destination bits
+ Value alignedVal =
+ rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset)
+ .getResult();
+
+ // Clear destination bits
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+ writeMask, adaptor.getMemref(),
+ storeIndices);
+ // Write src bits to destination
+ rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+ alignedVal, adaptor.getMemref(),
+ storeIndices);
+
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
@@ -291,8 +467,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `memref.*` conversion patterns.
- patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
- ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+ patterns.add<ConvertMemRefAlloc<memref::AllocOp>,
+ ConvertMemRefAlloc<memref::AllocaOp>, ConvertMemRefLoad,
+ ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
+ ConvertMemrefStore, ConvertMemRefReinterpretCast>(
typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cff2..05ec5761c8fe024 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,231 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
+ %0 = memref.alloca() : memref<5xi4>
+ %1 = memref.load %0[%arg0] : memref<5xi4>
+ return %1 : i4
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+// CHECK: func @memref_alloca_load_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: return %[[TRUNC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+// CHECK32: func @memref_alloca_load_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index
+// CHECK32: %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: return %[[TRUNC]]
+
+// -----
+
+func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
+ %0 = memref.alloc() : memref<5xi4>
+ memref.store %arg1, %0[%arg0] : memref<5xi4>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * -4 + (s0 floordiv 2) * 8 + 4)>
+// CHECK: func @memref_store_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %alloc[%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %alloc[%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * -4 + (s0 floordiv 8) * 32 + 28)>
+// CHECK32: func @memref_store_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], ...
[truncated]
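To make the masking in `getAtomicWriteMask` concrete, here is a small standalone C++ check (illustration only, not code from the patch) for `srcBits = 4`, `dstBits = 8`, and a bit offset of 4, i.e. an i4 written into the high nibble of an i8:

```cpp
// Standalone illustration of the mask arithmetic in getAtomicWriteMask for
// srcBits = 4, dstBits = 8, and a bit offset of 4.
#include <cassert>
#include <cstdint>

int main() {
  const int srcBits = 4;
  const uint8_t bitwidthOffset = 4;                                    // shift amount
  const uint8_t maskRightAligned = (1 << srcBits) - 1;                 // 0x0F
  const uint8_t writeMaskInverse = maskRightAligned << bitwidthOffset; // 0xF0
  const uint8_t writeMask = writeMaskInverse ^ 0xFF;                   // XOR with -1
  // atomic_rmw andi with 0x0F clears the destination nibble; atomic_rmw ori
  // with (value << 4) then writes the i4 into it.
  assert(writeMask == 0x0F);
  return 0;
}
```

This is why the emulated store needs two atomics: one `andi` to clear the slot and one `ori` to write the shifted value.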
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from 9901854 to bbc87f8.
Adds narrow type emulation support for: `memref.alloca`, `memref.store`, `memref.reinterpret_cast`
Force-pushed from bbc87f8 to 1b4a88b.
Could we split the …
Can we split it to three PRs?
- `memref.store`
- `memref.reinterpret_cast`
- `memref.alloca`
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
-  using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> {
Perhaps we can rename it to `ConvertMemRefAllocation`. It is used not only by memref.alloc, but also by memref.alloca.
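For concreteness, a sketch of what that rename could look like on the templated pattern and its registration from this patch; `ConvertMemRefAllocation` is the reviewer's proposed name, not something the patch defines:

```cpp
// Declaration sketch only: the reviewer-suggested name applied to the
// templated pattern from this patch; the body would be unchanged.
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  // ... matchAndRewrite as in the patch, handling memref.alloc and
  // memref.alloca through OpTy ...
};

// Registration would then read:
//   patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
//                ConvertMemRefAllocation<memref::AllocaOp>, ...>(
//       typeConverter, patterns.getContext());
```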
//===----------------------------------------------------------------------===//

/// Currently there is very limited support for memref::ReinterpretCastOp
/// conversion. Only the 0 dimensional case is supported.
Why is only the 0-D case supported? Are we able to emulate other cases?
Value maskRightAligned =
    builder
        .create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
        .getResult();
IIRC, `.getResult()` is not needed when the `Value` type is specified. Same for the other code below.
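For illustration, a minimal sketch of that suggestion; the helper name `createMaskConstant` is made up for the example, while the builder call is the one already used in the patch:

```cpp
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Because arith.constant has a single result, the created op converts
// implicitly to Value, so the trailing .getResult() can be dropped.
static Value createMaskConstant(OpBuilder &builder, Location loc,
                                IntegerType dstIntegerType,
                                IntegerAttr maskRightAlignedAttr) {
  return builder.create<arith::ConstantOp>(loc, dstIntegerType,
                                           maskRightAlignedAttr);
}
```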
rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
                                     alignedVal, adaptor.getMemref(),
                                     ValueRange{});
rewriter.eraseOp(op);
Can we use `replaceOp` instead? That's more common in pattern rewrites.
// Special case 0-rank memref stores. We can compute the mask at compile
// time.
if (convertedType.getRank() == 0) {
The code is becoming larger than I expected. Let's create two static functions: one for the 0-D case and one for the non-0-D case. What do you think?
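A rough sketch of that split; the helper names and parameter lists below are hypothetical, not part of the patch:

```cpp
// Hypothetical refactor sketch: hoist the two branches of
// ConvertMemrefStore::matchAndRewrite into static helpers.
static void emulateZeroRankStore(ConversionPatternRewriter &rewriter,
                                 Location loc, Value extendedInput,
                                 Value dstMemref, int64_t srcBits,
                                 int64_t dstBits) {
  // Compile-time shift and mask, then the andi/ori atomic_rmw pair with an
  // empty index list (the rank-0 branch of the current pattern).
}

static void emulateStridedStore(ConversionPatternRewriter &rewriter,
                                Location loc, Value extendedInput,
                                Value dstMemref, ValueRange indices,
                                int64_t srcBits, int64_t dstBits) {
  // Linearize the indices, compute the runtime mask and shift, then the same
  // andi/ori atomic_rmw pair at the scaled index (the non-0-D branch).
}
```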
@@ -35,36 +36,98 @@ using namespace mlir;
 /// Return the bit offset of the value at position `srcIdx`. For example, if
 /// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
 /// located at (x % 2) * 4. Because there are two elements in one i8, and one
-/// element has 4 bits.
+/// element has 4 bits. If `rightOffset` is true, return the offset from the
+/// right side of the `dstBits` container instead of the left side.
This is confusing. I'd rather have two methods, `getLeftOffset...` and `getRightOffset...` (also, maybe it's worth finding something better than "left" and "right").
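One way that could look, sketched as thin wrappers over the patch's existing `getOffsetForBitwidth`; the wrapper names follow the reviewer's suggestion and are not part of the patch:

```cpp
// Hypothetical wrappers over the patch's getOffsetForBitwidth; only the
// names are new.
static Value getLeftOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                      int sourceBits, int targetBits,
                                      OpBuilder &builder) {
  return getOffsetForBitwidth(loc, srcIdx, sourceBits, targetBits, builder,
                              /*rightOffset=*/false);
}

static Value getRightOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                       int sourceBits, int targetBits,
                                       OpBuilder &builder) {
  return getOffsetForBitwidth(loc, srcIdx, sourceBits, targetBits, builder,
                              /*rightOffset=*/true);
}
```

An enum naming the alignment (most- vs. least-significant end) could replace the boolean if "left"/"right" is the concern.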
Adds narrow type emulation support for:
- `memref.alloca`
- `memref.store`
- `memref.reinterpret_cast`
Fixes iree-org/iree#15370