
Commit b29332a

[mlir] Add narrow type emulation for memref.reinterpret_cast (#73144)
1 parent 30afb21 commit b29332a

2 files changed: 153 additions and 41 deletions

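As a quick illustration of what the new pattern enables, the sketch below pairs the rank-0 case from the tests added in this commit with the IR produced when i4 values are emulated with 8-bit storage (the plain CHECK configuration; CHECK32 uses 32-bit storage). The function name is illustrative; the types and operations mirror the CHECK lines in the test diff further down.

// Input: a rank-0 reinterpret_cast over i4 elements.
func.func @cast_i4_0d() -> i4 {
  %0 = memref.alloc() : memref<5xi4>
  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: []
         : memref<5xi4> to memref<i4>
  %2 = memref.load %1[] : memref<i4>
  return %2 : i4
}

// After emulation with 8-bit storage: the allocation is packed into bytes
// (5 x i4 = 20 bits -> 3 bytes), the cast is re-created over the i8 memref,
// and the loaded byte is truncated back to i4.
func.func @cast_i4_0d() -> i4 {
  %0 = memref.alloc() : memref<3xi8>
  %1 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: []
         : memref<3xi8> to memref<i8>
  %2 = memref.load %1[] : memref<i8>
  %3 = arith.trunci %2 : i8 to i4
  return %3 : i4
}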

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 95 additions & 41 deletions
@@ -17,18 +17,77 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
+#include <type_traits>

 using namespace mlir;

 //===----------------------------------------------------------------------===//
 // Utility functions
 //===----------------------------------------------------------------------===//

+/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
+/// type. The result MemRefType of the old op must have a rank and stride of 1,
+/// with static offset and size. The number of bits in the offset must evenly
+/// divide the bitwidth of the new converted type.
+template <typename MemRefOpTy>
+static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
+                                      typename MemRefOpTy::Adaptor adaptor,
+                                      MemRefOpTy op, MemRefType newTy) {
+  static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
+                    std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
+                "Expected only memref::SubViewOp or memref::ReinterpretCastOp");
+
+  auto convertedElementType = newTy.getElementType();
+  auto oldElementType = op.getType().getElementType();
+  int srcBits = oldElementType.getIntOrFloatBitWidth();
+  int dstBits = convertedElementType.getIntOrFloatBitWidth();
+  if (dstBits % srcBits != 0) {
+    return rewriter.notifyMatchFailure(op,
+                                       "only dstBits % srcBits == 0 supported");
+  }
+
+  // Only support stride of 1.
+  if (llvm::any_of(op.getStaticStrides(),
+                   [](int64_t stride) { return stride != 1; })) {
+    return rewriter.notifyMatchFailure(op->getLoc(),
+                                       "stride != 1 is not supported");
+  }
+
+  auto sizes = op.getStaticSizes();
+  int64_t offset = op.getStaticOffset(0);
+  // Only support static sizes and offsets.
+  if (llvm::any_of(sizes,
+                   [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+      offset == ShapedType::kDynamic) {
+    return rewriter.notifyMatchFailure(
+        op->getLoc(), "dynamic size or offset is not supported");
+  }
+
+  int elementsPerByte = dstBits / srcBits;
+  if (offset % elementsPerByte != 0) {
+    return rewriter.notifyMatchFailure(
+        op->getLoc(), "offset not multiple of elementsPerByte is not "
+                      "supported");
+  }
+
+  SmallVector<int64_t> size;
+  if (sizes.size())
+    size.push_back(ceilDiv(sizes[0], elementsPerByte));
+  offset = offset / elementsPerByte;
+
+  rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
+                                          *adaptor.getODSOperands(0).begin(),
+                                          offset, size, op.getStaticStrides());
+  return success();
+}
+
 /// When data is loaded/stored in `targetBits` granularity, but is used in
 /// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
 /// treated as an array of elements of width `sourceBits`.
@@ -211,6 +270,37 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };

+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+/// Output types should be at most one dimensional, so only the 0 or 1
+/// dimensional cases are supported.
+struct ConvertMemRefReinterpretCast final
+    : OpConversionPattern<memref::ReinterpretCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType newTy =
+        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    // Only support for 0 or 1 dimensional cases.
+    if (op.getType().getRank() > 1) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "subview with rank > 1 is not supported");
+    }
+
+    return convertCastingOp(rewriter, adaptor, op, newTy);
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
@@ -233,50 +323,13 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
           llvm::formatv("failed to convert memref type: {0}", op.getType()));
     }

-    auto convertedElementType = newTy.getElementType();
-    auto oldElementType = op.getType().getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = convertedElementType.getIntOrFloatBitWidth();
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
-    }
-
     // Only support offset for 1-D subview.
     if (op.getType().getRank() != 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

-    // Only support stride of 1.
-    if (op.getStaticStride(0) != 1) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with stride != 1 is not supported");
-    }
-
-    int64_t size = op.getStaticSize(0);
-    int64_t offset = op.getStaticOffset(0);
-    // Only support static sizes and offsets.
-    if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with dynamic size or offset is not supported");
-    }
-
-    int elementsPerByte = dstBits / srcBits;
-    if (offset % elementsPerByte != 0) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(),
-          "subview with offset not multiple of elementsPerByte is not "
-          "supported");
-    }
-
-    size = ceilDiv(size, elementsPerByte);
-    offset = offset / elementsPerByte;
-
-    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
-        op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
-        op.getStaticStrides());
-    return success();
+    return convertCastingOp(rewriter, adaptor, op, newTy);
   }
 };

@@ -291,9 +344,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {

   // Populate `memref.*` conversion patterns.
-  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
-      typeConverter, patterns.getContext());
+  patterns
+      .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
+           ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
+          typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
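Before the test diff, here is the offset and size rescaling that convertCastingOp performs, worked for the 1-D i4-to-i8 case exercised below. The SSA names are illustrative; the arithmetic in the comments follows the code above (elementsPerByte = dstBits / srcBits, the offset divided by it, the size rounded up with ceilDiv).

// Original cast over i4 elements: offset 8, size 25, stride 1.
%0 = memref.alloc() : memref<5x5xi4>
%1 = memref.reinterpret_cast %0 to offset: [8], sizes: [25], strides: [1]
       : memref<5x5xi4> to memref<25xi4, strided<[1], offset: 8>>

// Rewritten cast over the emulated i8 buffer:
//   elementsPerByte = 8 / 4          = 2
//   new offset      = 8 / 2          = 4
//   new size        = ceilDiv(25, 2) = 13
// (The alloc itself is rewritten by ConvertMemRefAlloc: 25 x i4 = 100 bits -> 13 bytes.)
%0 = memref.alloc() : memref<13xi8>
%1 = memref.reinterpret_cast %0 to offset: [4], sizes: [13], strides: [1]
       : memref<13xi8> to memref<13xi8, strided<[1], offset: 4>>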

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 58 additions & 0 deletions
@@ -174,3 +174,61 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
 // CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
 // CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
 // CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @reinterpret_cast_memref_load_0D() -> i4 {
+  %0 = memref.alloc() : memref<5xi4>
+  %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
+  %1 = memref.load %reinterpret_cast_0[] : memref<i4>
+  return %1 : i4
+}
+// CHECK-LABEL: func @reinterpret_cast_memref_load_0D()
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<3xi8> to memref<i8>
+// CHECK: %[[LOAD:.+]] = memref.load %[[RE_CAST]][] : memref<i8>
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i8 to i4
+// CHECK: return %[[TRUNC]]
+
+// CHECK32-LABEL: func @reinterpret_cast_memref_load_0D()
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+// CHECK32: %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<1xi32> to memref<i32>
+// CHECK32: %[[LOAD:.+]] = memref.load %[[RE_CAST]][] : memref<i32>
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4
+// CHECK32: return %[[TRUNC]]
+
+// -----
+
+func.func @reinterpret_cast_memref_load_1D(%arg0: index) -> i4 {
+  %0 = memref.alloc() : memref<5x5xi4>
+  %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [8], sizes: [25], strides: [1] : memref<5x5xi4> to memref<25xi4, strided<[1], offset:8>>
+  %1 = memref.load %reinterpret_cast_0[%arg0] : memref<25xi4, strided<[1], offset:8>>
+  return %1 : i4
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+// CHECK: func @reinterpret_cast_memref_load_1D(
+// CHECK-SAME: %[[ARG0:.+]]: index
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<13xi8>
+// CHECK: %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [4], sizes: [13], strides: [1] : memref<13xi8> to memref<13xi8, strided<[1], offset: 4>>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+// CHECK: %[[LOAD:.+]] = memref.load %[[RE_CAST]][%[[INDEX]]] : memref<13xi8, strided<[1], offset: 4>>
+// CHECK: %[[OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK: %[[CAST:.+]] = arith.index_cast %[[OFFSET]] : index to i8
+// CHECK: %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i8
+// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i8 to i4
+// CHECK: return %[[TRUNC]]
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+// CHECK32: func @reinterpret_cast_memref_load_1D(
+// CHECK32-SAME: %[[ARG0:.+]]: index
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
+// CHECK32: %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [1], sizes: [4], strides: [1] : memref<4xi32> to memref<4xi32, strided<[1], offset: 1>>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
+// CHECK32: %[[LOAD:.+]] = memref.load %[[RE_CAST]][%[[INDEX]]] : memref<4xi32, strided<[1], offset: 1>>
+// CHECK32: %[[OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32: %[[CAST:.+]] = arith.index_cast %[[OFFSET]] : index to i32
+// CHECK32: %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i32
+// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i32 to i4
+// CHECK32: return %[[TRUNC]]
