Skip to content

[mlir] Add narrow type emulation for memref.reinterpret_cast #73144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 27, 2023

Conversation

Max191
Copy link
Contributor

@Max191 Max191 commented Nov 22, 2023

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Nov 22, 2023

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: None (Max191)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/73144.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+95-41)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..dec5936fa7e83ce 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,11 +17,14 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
+#include <type_traits>
 
 using namespace mlir;
 
@@ -29,6 +32,62 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
+/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
+/// type. The result MemRefType of the old op must have a rank and stride of 1,
+/// with static offset and size. The number of bits in the offset must evenly
+/// divide the bitwidth of the new converted type.
+template <typename MemRefOpTy>
+static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
+                                      typename MemRefOpTy::Adaptor adaptor,
+                                      MemRefOpTy op, MemRefType newTy) {
+  static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
+                    std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
+                "Expected only memref::SubViewOp or memref::ReinterpretCastOp");
+
+  auto convertedElementType = newTy.getElementType();
+  auto oldElementType = op.getType().getElementType();
+  int srcBits = oldElementType.getIntOrFloatBitWidth();
+  int dstBits = convertedElementType.getIntOrFloatBitWidth();
+  if (dstBits % srcBits != 0) {
+    return rewriter.notifyMatchFailure(op,
+                                       "only dstBits % srcBits == 0 supported");
+  }
+
+  // Only support stride of 1.
+  if (llvm::any_of(op.getStaticStrides(),
+                   [](int64_t stride) { return stride != 1; })) {
+    return rewriter.notifyMatchFailure(op->getLoc(),
+                                       "stride != 1 is not supported");
+  }
+
+  auto sizes = op.getStaticSizes();
+  int64_t offset = op.getStaticOffset(0);
+  // Only support static sizes and offsets.
+  if (llvm::any_of(sizes,
+                   [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+      offset == ShapedType::kDynamic) {
+    return rewriter.notifyMatchFailure(
+        op->getLoc(), "dynamic size or offset is not supported");
+  }
+
+  int elementsPerByte = dstBits / srcBits;
+  if (offset % elementsPerByte != 0) {
+    return rewriter.notifyMatchFailure(
+        op->getLoc(), "offset not multiple of elementsPerByte is not "
+                      "supported");
+  }
+
+  SmallVector<int64_t> size;
+  if (sizes.size())
+    size.push_back(ceilDiv(sizes[0], elementsPerByte));
+  offset = offset / elementsPerByte;
+
+  rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
+                                          *adaptor.getODSOperands(0).begin(),
+                                          offset, size, op.getStaticStrides());
+  return success();
+}
+
 /// When data is loaded/stored in `targetBits` granularity, but is used in
 /// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
 /// treated as an array of elements of width `sourceBits`.
@@ -211,6 +270,37 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+/// Output types should be at most one dimensional, so only the 0 or 1
+/// dimensional cases are supported.
+struct ConvertMemRefReinterpretCast final
+    : OpConversionPattern<memref::ReinterpretCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType newTy =
+        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    // Only support for 0 or 1 dimensional cases.
+    if (op.getType().getRank() > 1) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "subview with rank > 1 is not supported");
+    }
+
+    return convertCastingOp(rewriter, adaptor, op, newTy);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
@@ -233,50 +323,13 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
           llvm::formatv("failed to convert memref type: {0}", op.getType()));
     }
 
-    auto convertedElementType = newTy.getElementType();
-    auto oldElementType = op.getType().getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = convertedElementType.getIntOrFloatBitWidth();
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
-    }
-
     // Only support offset for 1-D subview.
     if (op.getType().getRank() != 1) {
       return rewriter.notifyMatchFailure(
           op->getLoc(), "subview with rank > 1 is not supported");
     }
 
-    // Only support stride of 1.
-    if (op.getStaticStride(0) != 1) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with stride != 1 is not supported");
-    }
-
-    int64_t size = op.getStaticSize(0);
-    int64_t offset = op.getStaticOffset(0);
-    // Only support static sizes and offsets.
-    if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with dynamic size or offset is not supported");
-    }
-
-    int elementsPerByte = dstBits / srcBits;
-    if (offset % elementsPerByte != 0) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(),
-          "subview with offset not multiple of elementsPerByte is not "
-          "supported");
-    }
-
-    size = ceilDiv(size, elementsPerByte);
-    offset = offset / elementsPerByte;
-
-    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
-        op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
-        op.getStaticStrides());
-    return success();
+    return convertCastingOp(rewriter, adaptor, op, newTy);
   }
 };
 
@@ -291,9 +344,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `memref.*` conversion patterns.
-  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
-      typeConverter, patterns.getContext());
+  patterns
+      .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
+           ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
+          typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
 

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably need a lit test for the reinterpret_cast support.

@Max191
Copy link
Contributor Author

Max191 commented Nov 22, 2023

Probably need a lit test for the reinterpret_cast support.

Oh, I didn't realize somehow the test got lost in this branch. I'll add it back

@Max191 Max191 force-pushed the memref-reinterpret-cast-subbyte-emulation branch from f46ab4e to b72b7cd Compare November 27, 2023 17:45
@hanhanW hanhanW self-requested a review November 27, 2023 18:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants