#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
+#include <type_traits>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

+/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
+/// type. The result MemRefType of the old op must have a rank and stride of 1,
+/// with static offset and size. The number of bits in the offset must evenly
+/// divide the bitwidth of the new converted type.
+template <typename MemRefOpTy>
+static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
+                                      typename MemRefOpTy::Adaptor adaptor,
+                                      MemRefOpTy op, MemRefType newTy) {
+  static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
+                    std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
+                "Expected only memref::SubViewOp or memref::ReinterpretCastOp");
+
+  auto convertedElementType = newTy.getElementType();
+  auto oldElementType = op.getType().getElementType();
+  int srcBits = oldElementType.getIntOrFloatBitWidth();
+  int dstBits = convertedElementType.getIntOrFloatBitWidth();
+  if (dstBits % srcBits != 0) {
+    return rewriter.notifyMatchFailure(op,
+                                       "only dstBits % srcBits == 0 supported");
+  }
+
+  // Only support stride of 1.
+  if (llvm::any_of(op.getStaticStrides(),
+                   [](int64_t stride) { return stride != 1; })) {
+    return rewriter.notifyMatchFailure(op->getLoc(),
+                                       "stride != 1 is not supported");
+  }
+
+  auto sizes = op.getStaticSizes();
+  int64_t offset = op.getStaticOffset(0);
+  // Only support static sizes and offsets.
+  if (llvm::any_of(sizes,
+                   [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+      offset == ShapedType::kDynamic) {
+    return rewriter.notifyMatchFailure(
+        op->getLoc(), "dynamic size or offset is not supported");
+  }
+
+  int elementsPerByte = dstBits / srcBits;
+  if (offset % elementsPerByte != 0) {
+    return rewriter.notifyMatchFailure(
+        op->getLoc(), "offset not multiple of elementsPerByte is not "
+                      "supported");
+  }
+
+  SmallVector<int64_t> size;
+  if (sizes.size())
+    size.push_back(ceilDiv(sizes[0], elementsPerByte));
+  offset = offset / elementsPerByte;
+
+  rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
+                                          *adaptor.getODSOperands(0).begin(),
+                                          offset, size, op.getStaticStrides());
+  return success();
+}
+
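To make the rescaling concrete: the following standalone sketch is not part of the patch; the i4-over-i8 numbers and the local `ceilDivExample` helper are assumptions chosen purely for illustration of the arithmetic `convertCastingOp` applies to a cast's static offset and size.

```cpp
// Standalone sketch (not part of the patch) of the offset/size rescaling
// performed by convertCastingOp, assuming i4 elements emulated in i8 storage.
#include <cassert>
#include <cstdint>
#include <cstdio>

// Hypothetical local stand-in for the ceilDiv helper used above.
static int64_t ceilDivExample(int64_t lhs, int64_t rhs) {
  return (lhs + rhs - 1) / rhs;
}

int main() {
  const int srcBits = 4, dstBits = 8;
  assert(dstBits % srcBits == 0 && "only dstBits % srcBits == 0 supported");
  const int elementsPerByte = dstBits / srcBits; // 2 i4 values per byte.

  // A 1-D cast with static offset 6 and size 5 (counted in i4 elements)...
  int64_t offset = 6, size = 5;
  assert(offset % elementsPerByte == 0 && "offset must land on a byte boundary");

  // ...is rewritten to offset 3 and size 3 in the emulated i8 memref.
  offset /= elementsPerByte;
  size = ceilDivExample(size, elementsPerByte);
  std::printf("new offset = %lld, new size = %lld\n", (long long)offset,
              (long long)size);
  return 0;
}
```

Note that, as the guard above shows, an offset that is not a multiple of `elementsPerByte` is rejected with a match failure rather than rescaled.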
/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
/// treated as an array of elements of width `sourceBits`.
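As a rough illustration of that packing, assuming 4-bit elements emulated in 8-bit words: the sketch below shows only the index arithmetic; the authoritative bit layout is whatever the emulation helpers in this file compute.

```cpp
// Illustrative index arithmetic only, assuming sourceBits = 4 and
// targetBits = 8; treat the bit ordering as an assumption, not as the
// definitive layout used by the emulation helpers.
#include <cstdint>
#include <cstdio>

int main() {
  const int sourceBits = 4, targetBits = 8;
  const int scale = targetBits / sourceBits; // 2 i4 elements per i8 word.
  for (int64_t srcIdx = 0; srcIdx < 4; ++srcIdx) {
    int64_t wordIdx = srcIdx / scale;                  // which i8 word
    int64_t bitOffset = (srcIdx % scale) * sourceBits; // position inside it
    std::printf("i4 element %lld -> i8 word %lld, bit offset %lld\n",
                (long long)srcIdx, (long long)wordIdx, (long long)bitOffset);
  }
  return 0;
}
```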
@@ -211,6 +270,37 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  }
};

+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+/// Output types should be at most one dimensional, so only the 0 or 1
+/// dimensional cases are supported.
+struct ConvertMemRefReinterpretCast final
+    : OpConversionPattern<memref::ReinterpretCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType newTy =
+        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    // Only support for 0 or 1 dimensional cases.
+    if (op.getType().getRank() > 1) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "subview with rank > 1 is not supported");
+    }
+
+    return convertCastingOp(rewriter, adaptor, op, newTy);
+  }
+};
+
//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
@@ -233,50 +323,13 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

-    auto convertedElementType = newTy.getElementType();
-    auto oldElementType = op.getType().getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = convertedElementType.getIntOrFloatBitWidth();
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
-    }
-
    // Only support offset for 1-D subview.
    if (op.getType().getRank() != 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

-    // Only support stride of 1.
-    if (op.getStaticStride(0) != 1) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with stride != 1 is not supported");
-    }
-
-    int64_t size = op.getStaticSize(0);
-    int64_t offset = op.getStaticOffset(0);
-    // Only support static sizes and offsets.
-    if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with dynamic size or offset is not supported");
-    }
-
-    int elementsPerByte = dstBits / srcBits;
-    if (offset % elementsPerByte != 0) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(),
-          "subview with offset not multiple of elementsPerByte is not "
-          "supported");
-    }
-
-    size = ceilDiv(size, elementsPerByte);
-    offset = offset / elementsPerByte;
-
-    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
-        op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
-        op.getStaticStrides());
-    return success();
+    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};

@@ -291,9 +344,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
    RewritePatternSet &patterns) {

  // Populate `memref.*` conversion patterns.
-  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
-      typeConverter, patterns.getContext());
+  patterns
+      .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
+           ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
+          typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
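For context, here is a minimal sketch of how a conversion pass might drive this entry point. The pass wrapper, the 8-bit target width, and the use of `arith::NarrowTypeEmulationConverter` (including its header path) are assumptions for illustration, not something this diff adds.

```cpp
// Sketch only: wiring the narrow-type emulation patterns into a conversion.
// The converter header path and the 8-bit target width are assumptions.
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static void runNarrowTypeEmulation(Operation *rootOp, MLIRContext *ctx) {
  // Emulate sub-byte element types (e.g. i4) with 8-bit storage.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);

  RewritePatternSet patterns(ctx);
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);

  ConversionTarget target(*ctx);
  // Ops whose operand/result types the converter already accepts are legal.
  target.markUnknownOpDynamicallyLegal(
      [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });

  if (failed(applyPartialConversion(rootOp, target, std::move(patterns))))
    rootOp->emitError("narrow type emulation failed");
}
```

Gating legality on the type converter is one common setup; an actual pass may instead use a narrower, dialect-specific conversion target.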