-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Add support for memref.alloca
sub-byte emulation
#73138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
94bace3
to
f517a95
Compare
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir Author: None (Max191) Changes: Adds a similar case to `memref.alloc` for `memref.alloca` in EmulateNarrowTypes. Full diff: https://github.com/llvm/llvm-project/pull/73138.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..78c523f08ea301a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -53,18 +53,22 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
namespace {
//===----------------------------------------------------------------------===//
-// ConvertMemRefAlloc
+// ConvertMemRefAllocation
//===----------------------------------------------------------------------===//
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
- using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
+ using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult
- matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+ matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto currentType = op.getMemref().getType().cast<MemRefType>();
- auto newResultType =
- getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+ static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+ std::is_same<OpTy, memref::AllocaOp>(),
+ "expected only memref::AllocOp or memref::AllocaOp");
+ auto currentType = cast<MemRefType>(op.getMemref().getType());
+ auto newResultType = dyn_cast<MemRefType>(
+ this->getTypeConverter()->convertType(op.getType()));
if (!newResultType) {
return rewriter.notifyMatchFailure(
op->getLoc(),
@@ -73,9 +77,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
// Special case zero-rank memrefs.
if (currentType.getRank() == 0) {
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
- adaptor.getAlignmentAttr());
+ rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+ adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
return success();
}
@@ -97,9 +101,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
rewriter, loc, linearizedMemRefInfo.linearizedSize));
}
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
- adaptor.getAlignmentAttr());
+ rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+ adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
return success();
}
};
@@ -291,7 +295,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `memref.*` conversion patterns.
- patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
+ patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
+ ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cff2..a25b0a668499a23 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,36 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
+ %0 = memref.alloca() : memref<5xi4>
+ %1 = memref.load %0[%arg0] : memref<5xi4>
+ return %1 : i4
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+// CHECK: func @memref_alloca_load_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: return %[[TRUNC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+// CHECK32: func @memref_alloca_load_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index
+// CHECK32: %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: return %[[TRUNC]]
|
struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> { | ||
using OpConversionPattern::OpConversionPattern; | ||
template <typename OpTy> | ||
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you think we can do this using an `AllocLikeOpInterface`? It's fine to do this as a follow-up as well (or leave a TODO saying this could be generalized using the interface).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please help fix the conflicts.
f517a95
to
7c1a7f7
Compare
Adds a similar case to `memref.alloc` for `memref.alloca` in EmulateNarrowTypes.
Fixes iree-org/iree#15515