Skip to content

[mlir] Add support for memref.alloca sub-byte emulation #73138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 28, 2023

Conversation

Max191
Copy link
Contributor

@Max191 Max191 commented Nov 22, 2023

Adds a similar case to memref.alloc for memref.alloca in EmulateNarrowTypes.

Fixes iree-org/iree#15515

@Max191 Max191 force-pushed the memref-alloca-subbyte branch from 94bace3 to f517a95 Compare November 22, 2023 15:38
@Max191 Max191 marked this pull request as ready for review November 22, 2023 15:38
@llvmbot
Copy link
Member

llvmbot commented Nov 22, 2023

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

Adds a similar case to memref.alloc for memref.alloca in EmulateNarrowTypes.


Full diff: https://github.com/llvm/llvm-project/pull/73138.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+19-14)
  • (modified) mlir/test/Dialect/MemRef/emulate-narrow-type.mlir (+33)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..78c523f08ea301a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -53,18 +53,22 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
 namespace {
 
 //===----------------------------------------------------------------------===//
-// ConvertMemRefAlloc
+// ConvertMemRefAllocation
 //===----------------------------------------------------------------------===//
 
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
-  using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto currentType = op.getMemref().getType().cast<MemRefType>();
-    auto newResultType =
-        getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+                      std::is_same<OpTy, memref::AllocaOp>(),
+                  "expected only memref::AllocOp or memref::AllocaOp");
+    auto currentType = cast<MemRefType>(op.getMemref().getType());
+    auto newResultType = dyn_cast<MemRefType>(
+        this->getTypeConverter()->convertType(op.getType()));
     if (!newResultType) {
       return rewriter.notifyMatchFailure(
           op->getLoc(),
@@ -73,9 +77,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
 
     // Special case zero-rank memrefs.
     if (currentType.getRank() == 0) {
-      rewriter.replaceOpWithNewOp<memref::AllocOp>(
-          op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
-          adaptor.getAlignmentAttr());
+      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+                                        adaptor.getSymbolOperands(),
+                                        adaptor.getAlignmentAttr());
       return success();
     }
 
@@ -97,9 +101,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
           rewriter, loc, linearizedMemRefInfo.linearizedSize));
     }
 
-    rewriter.replaceOpWithNewOp<memref::AllocOp>(
-        op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
-        adaptor.getAlignmentAttr());
+    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+                                      adaptor.getSymbolOperands(),
+                                      adaptor.getAlignmentAttr());
     return success();
   }
 };
@@ -291,7 +295,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `memref.*` conversion patterns.
-  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
+  patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
+               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
                ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
       typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cff2..a25b0a668499a23 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,36 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
 //       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
 //       CHECK32:   %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
 //       CHECK32:   %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
+    %0 = memref.alloca() : memref<5xi4>
+    %1 = memref.load %0[%arg0] : memref<5xi4>
+    return %1 : i4
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+//      CHECK: func @memref_alloca_load_i4(
+// CHECK-SAME:     %[[ARG0:.+]]: index
+//      CHECK:   %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK:   %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+//      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+//      CHECK:   return %[[TRUNC]]
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+//      CHECK32: func @memref_alloca_load_i4(
+// CHECK32-SAME:     %[[ARG0:.+]]: index
+//      CHECK32:   %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK32:   %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+//      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+//      CHECK32:   return %[[TRUNC]]

struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
template <typename OpTy>
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we can do this using an AllocLikeOpInterface ? Its fine to do as a follow up as well (or leave a TODO saying this could be generalized using the interface).

Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please help fix the conflicts.

@Max191 Max191 force-pushed the memref-alloca-subbyte branch from f517a95 to 7c1a7f7 Compare November 27, 2023 20:24
@hanhanW hanhanW merged commit b823f84 into llvm:main Nov 28, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Missing sub-byte emulation support for memref.alloca
4 participants