Skip to content

Commit 7c1a7f7

Browse files
committed
[mlir] Add support for memref.alloca sub-byte emulation
1 parent 7bfcce3 commit 7c1a7f7

File tree

2 files changed

+55
-17
lines changed

2 files changed

+55
-17
lines changed

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -112,18 +112,22 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
112112
namespace {
113113

114114
//===----------------------------------------------------------------------===//
115-
// ConvertMemRefAlloc
115+
// ConvertMemRefAllocation
116116
//===----------------------------------------------------------------------===//
117117

118-
struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
119-
using OpConversionPattern::OpConversionPattern;
118+
template <typename OpTy>
119+
struct ConvertMemRefAllocation final : OpConversionPattern<OpTy> {
120+
using OpConversionPattern<OpTy>::OpConversionPattern;
120121

121122
LogicalResult
122-
matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
123+
matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
123124
ConversionPatternRewriter &rewriter) const override {
124-
auto currentType = op.getMemref().getType().cast<MemRefType>();
125-
auto newResultType =
126-
getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
125+
static_assert(std::is_same<OpTy, memref::AllocOp>() ||
126+
std::is_same<OpTy, memref::AllocaOp>(),
127+
"expected only memref::AllocOp or memref::AllocaOp");
128+
auto currentType = cast<MemRefType>(op.getMemref().getType());
129+
auto newResultType = dyn_cast<MemRefType>(
130+
this->getTypeConverter()->convertType(op.getType()));
127131
if (!newResultType) {
128132
return rewriter.notifyMatchFailure(
129133
op->getLoc(),
@@ -132,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
132136

133137
// Special case zero-rank memrefs.
134138
if (currentType.getRank() == 0) {
135-
rewriter.replaceOpWithNewOp<memref::AllocOp>(
136-
op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
137-
adaptor.getAlignmentAttr());
139+
rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
140+
adaptor.getSymbolOperands(),
141+
adaptor.getAlignmentAttr());
138142
return success();
139143
}
140144

@@ -156,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
156160
rewriter, loc, linearizedMemRefInfo.linearizedSize));
157161
}
158162

159-
rewriter.replaceOpWithNewOp<memref::AllocOp>(
160-
op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
161-
adaptor.getAlignmentAttr());
163+
rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
164+
adaptor.getSymbolOperands(),
165+
adaptor.getAlignmentAttr());
162166
return success();
163167
}
164168
};
@@ -344,10 +348,11 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
344348
RewritePatternSet &patterns) {
345349

346350
// Populate `memref.*` conversion patterns.
347-
patterns
348-
.add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
349-
ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
350-
typeConverter, patterns.getContext());
351+
patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
352+
ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
353+
ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
354+
ConvertMemRefReinterpretCast>(typeConverter,
355+
patterns.getContext());
351356
memref::populateResolveExtractStridedMetadataPatterns(patterns);
352357
}
353358

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,36 @@ func.func @reinterpret_cast_memref_load_1D(%arg0: index) -> i4 {
232232
// CHECK32: %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i32
233233
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i32 to i4
234234
// CHECK32: return %[[TRUNC]]
235+
236+
// -----
237+
238+
func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
239+
%0 = memref.alloca() : memref<5xi4>
240+
%1 = memref.load %0[%arg0] : memref<5xi4>
241+
return %1 : i4
242+
}
243+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
244+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
245+
// CHECK: func @memref_alloca_load_i4(
246+
// CHECK-SAME: %[[ARG0:.+]]: index
247+
// CHECK: %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
248+
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
249+
// CHECK: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
250+
// CHECK: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
251+
// CHECK: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
252+
// CHECK: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
253+
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
254+
// CHECK: return %[[TRUNC]]
255+
256+
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
257+
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
258+
// CHECK32: func @memref_alloca_load_i4(
259+
// CHECK32-SAME: %[[ARG0:.+]]: index
260+
// CHECK32: %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
261+
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
262+
// CHECK32: %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
263+
// CHECK32: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
264+
// CHECK32: %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
265+
// CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
266+
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
267+
// CHECK32: return %[[TRUNC]]

0 commit comments

Comments (0)