-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Add support for memref.alloca
sub-byte emulation
#71956
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-memref Author: None (Max191) ChangesFull diff: https://github.com/llvm/llvm-project/pull/71956.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..9b197dc51265d92 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -56,15 +56,19 @@ namespace {
// ConvertMemRefAlloc
//===----------------------------------------------------------------------===//
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
- using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> {
+ using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult
- matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+ matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto currentType = op.getMemref().getType().cast<MemRefType>();
- auto newResultType =
- getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+ static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+ std::is_same<OpTy, memref::AllocaOp>(),
+ "expected only memref::AllocOp or memref::AllocaOp");
+ auto currentType = cast<MemRefType>(op.getMemref().getType());
+ auto newResultType = dyn_cast<MemRefType>(
+ this->getTypeConverter()->convertType(op.getType()));
if (!newResultType) {
return rewriter.notifyMatchFailure(
op->getLoc(),
@@ -73,9 +77,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
// Special case zero-rank memrefs.
if (currentType.getRank() == 0) {
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
- adaptor.getAlignmentAttr());
+ rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+ adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
return success();
}
@@ -97,9 +101,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
rewriter, loc, linearizedMemRefInfo.linearizedSize));
}
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
- adaptor.getAlignmentAttr());
+ rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+ adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
return success();
}
};
@@ -291,7 +295,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `memref.*` conversion patterns.
- patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
+ patterns.add<ConvertMemRefAlloc<memref::AllocOp>,
+ ConvertMemRefAlloc<memref::AllocaOp>, ConvertMemRefLoad,
ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cff2..a25b0a668499a23 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,36 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
+ %0 = memref.alloca() : memref<5xi4>
+ %1 = memref.load %0[%arg0] : memref<5xi4>
+ return %1 : i4
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+// CHECK: func @memref_alloca_load_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: return %[[TRUNC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+// CHECK32: func @memref_alloca_load_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index
+// CHECK32: %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: return %[[TRUNC]]
|
@llvm/pr-subscribers-mlir Author: None (Max191) ChangesFull diff: https://github.com/llvm/llvm-project/pull/71956.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..9b197dc51265d92 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -56,15 +56,19 @@ namespace {
// ConvertMemRefAlloc
//===----------------------------------------------------------------------===//
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
- using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> {
+ using OpConversionPattern<OpTy>::OpConversionPattern;
LogicalResult
- matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+ matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto currentType = op.getMemref().getType().cast<MemRefType>();
- auto newResultType =
- getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+ static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+ std::is_same<OpTy, memref::AllocaOp>(),
+ "expected only memref::AllocOp or memref::AllocaOp");
+ auto currentType = cast<MemRefType>(op.getMemref().getType());
+ auto newResultType = dyn_cast<MemRefType>(
+ this->getTypeConverter()->convertType(op.getType()));
if (!newResultType) {
return rewriter.notifyMatchFailure(
op->getLoc(),
@@ -73,9 +77,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
// Special case zero-rank memrefs.
if (currentType.getRank() == 0) {
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
- adaptor.getAlignmentAttr());
+ rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+ adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
return success();
}
@@ -97,9 +101,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
rewriter, loc, linearizedMemRefInfo.linearizedSize));
}
- rewriter.replaceOpWithNewOp<memref::AllocOp>(
- op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
- adaptor.getAlignmentAttr());
+ rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+ adaptor.getSymbolOperands(),
+ adaptor.getAlignmentAttr());
return success();
}
};
@@ -291,7 +295,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `memref.*` conversion patterns.
- patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
+ patterns.add<ConvertMemRefAlloc<memref::AllocOp>,
+ ConvertMemRefAlloc<memref::AllocaOp>, ConvertMemRefLoad,
ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
typeConverter, patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cff2..a25b0a668499a23 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,36 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
+ %0 = memref.alloca() : memref<5xi4>
+ %1 = memref.load %0[%arg0] : memref<5xi4>
+ return %1 : i4
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+// CHECK: func @memref_alloca_load_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+// CHECK: return %[[TRUNC]]
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+// CHECK32: func @memref_alloca_load_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index
+// CHECK32: %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+// CHECK32: return %[[TRUNC]]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
I guess this never got pushed because I don't have write access :P This is now included in #72181, so I'll close this so we don't have a conflict. |
No description provided.