-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Add narrow type emulation for memref.reinterpret_cast
#73144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Add narrow type emulation for memref.reinterpret_cast
#73144
Conversation
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir Author: None (Max191) ChangesFull diff: https://github.com/llvm/llvm-project/pull/73144.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..dec5936fa7e83ce 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,11 +17,14 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
+#include <type_traits>
using namespace mlir;
@@ -29,6 +32,62 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//
+/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
+/// type. The result MemRefType of the old op must have a rank and stride of 1,
+/// with static offset and size. The number of bits in the offset must evenly
+/// divide the bitwidth of the new converted type.
+template <typename MemRefOpTy>
+static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
+ typename MemRefOpTy::Adaptor adaptor,
+ MemRefOpTy op, MemRefType newTy) {
+ static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
+ std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
+ "Expected only memref::SubViewOp or memref::ReinterpretCastOp");
+
+ auto convertedElementType = newTy.getElementType();
+ auto oldElementType = op.getType().getElementType();
+ int srcBits = oldElementType.getIntOrFloatBitWidth();
+ int dstBits = convertedElementType.getIntOrFloatBitWidth();
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(op,
+ "only dstBits % srcBits == 0 supported");
+ }
+
+ // Only support stride of 1.
+ if (llvm::any_of(op.getStaticStrides(),
+ [](int64_t stride) { return stride != 1; })) {
+ return rewriter.notifyMatchFailure(op->getLoc(),
+ "stride != 1 is not supported");
+ }
+
+ auto sizes = op.getStaticSizes();
+ int64_t offset = op.getStaticOffset(0);
+ // Only support static sizes and offsets.
+ if (llvm::any_of(sizes,
+ [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+ offset == ShapedType::kDynamic) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "dynamic size or offset is not supported");
+ }
+
+ int elementsPerByte = dstBits / srcBits;
+ if (offset % elementsPerByte != 0) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "offset not multiple of elementsPerByte is not "
+ "supported");
+ }
+
+ SmallVector<int64_t> size;
+ if (sizes.size())
+ size.push_back(ceilDiv(sizes[0], elementsPerByte));
+ offset = offset / elementsPerByte;
+
+ rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
+ *adaptor.getODSOperands(0).begin(),
+ offset, size, op.getStaticStrides());
+ return success();
+}
+
/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
/// treated as an array of elements of width `sourceBits`.
@@ -211,6 +270,37 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+/// Output types should be at most one dimensional, so only the 0 or 1
+/// dimensional cases are supported.
+struct ConvertMemRefReinterpretCast final
+ : OpConversionPattern<memref::ReinterpretCastOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ MemRefType newTy =
+ dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+ if (!newTy) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(),
+ llvm::formatv("failed to convert memref type: {0}", op.getType()));
+ }
+
+ // Only support for 0 or 1 dimensional cases.
+ if (op.getType().getRank() > 1) {
+ return rewriter.notifyMatchFailure(
+ op->getLoc(), "subview with rank > 1 is not supported");
+ }
+
+ return convertCastingOp(rewriter, adaptor, op, newTy);
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
@@ -233,50 +323,13 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
llvm::formatv("failed to convert memref type: {0}", op.getType()));
}
- auto convertedElementType = newTy.getElementType();
- auto oldElementType = op.getType().getElementType();
- int srcBits = oldElementType.getIntOrFloatBitWidth();
- int dstBits = convertedElementType.getIntOrFloatBitWidth();
- if (dstBits % srcBits != 0) {
- return rewriter.notifyMatchFailure(
- op, "only dstBits % srcBits == 0 supported");
- }
-
// Only support offset for 1-D subview.
if (op.getType().getRank() != 1) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with rank > 1 is not supported");
}
- // Only support stride of 1.
- if (op.getStaticStride(0) != 1) {
- return rewriter.notifyMatchFailure(
- op->getLoc(), "subview with stride != 1 is not supported");
- }
-
- int64_t size = op.getStaticSize(0);
- int64_t offset = op.getStaticOffset(0);
- // Only support static sizes and offsets.
- if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) {
- return rewriter.notifyMatchFailure(
- op->getLoc(), "subview with dynamic size or offset is not supported");
- }
-
- int elementsPerByte = dstBits / srcBits;
- if (offset % elementsPerByte != 0) {
- return rewriter.notifyMatchFailure(
- op->getLoc(),
- "subview with offset not multiple of elementsPerByte is not "
- "supported");
- }
-
- size = ceilDiv(size, elementsPerByte);
- offset = offset / elementsPerByte;
-
- rewriter.replaceOpWithNewOp<memref::SubViewOp>(
- op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
- op.getStaticStrides());
- return success();
+ return convertCastingOp(rewriter, adaptor, op, newTy);
}
};
@@ -291,9 +344,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `memref.*` conversion patterns.
- patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
- ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
- typeConverter, patterns.getContext());
+ patterns
+ .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
+ ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
+ typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably need a lit test for the reinterpret_cast
support.
Oh, I didn't realize somehow the test got lost in this branch. I'll add it back |
f46ab4e
to
b72b7cd
Compare
No description provided.