#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+ #include "mlir/IR/OpDefinition.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
@@ -35,36 +36,98 @@ using namespace mlir;
/// Return the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
/// located at (x % 2) * 4. Because there are two elements in one i8, and one
- /// element has 4 bits.
+ /// element has 4 bits. If `rightOffset` is true, return the offset from the
+ /// right side of the `dstBits` container instead of the left side.
static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                  int sourceBits, int targetBits,
-                                  OpBuilder &builder) {
+                                  OpBuilder &builder,
+                                  bool rightOffset = false) {
  assert(targetBits % sourceBits == 0);
  AffineExpr s0;
  bindSymbols(builder.getContext(), s0);
  int scaleFactor = targetBits / sourceBits;
-   OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
-       builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+   AffineExpr offsetExpr =
+       rightOffset ? (scaleFactor - 1 - s0 % scaleFactor) * sourceBits
+                   : (s0 % scaleFactor) * sourceBits;
+   OpFoldResult offsetVal =
+       affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
  IntegerType dstType = builder.getIntegerType(targetBits);
  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}
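// Illustrative check of the formula above, assuming sourceBits = 4 and
// targetBits = 8 (scaleFactor = 2): srcIdx = 5 gives a left offset of
// (5 % 2) * 4 = 4 and, with rightOffset = true, an offset of
// (2 - 1 - 5 % 2) * 4 = 0.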

+ /// When writing a subbyte size, writing needs to happen atomically in case of
+ /// another write happening on the same byte at the same time. To do the write,
+ /// we first must clear `dstBits` at the `linearizedIndices` of the subbyte
+ /// store. This function returns the appropriate mask for clearing these bits.
+ static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices,
+                                 int64_t srcBits, int64_t dstBits,
+                                 Value bitwidthOffset, OpBuilder &builder) {
+   auto dstIntegerType = builder.getIntegerType(dstBits);
+   auto maskRightAlignedAttr =
+       builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+   Value maskRightAligned =
+       builder
+           .create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
+           .getResult();
+   Value writeMaskInverse =
+       builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+   auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+   Value flipVal =
+       builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+           .getResult();
+   return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+ }
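// Illustrative mask values, assuming srcBits = 4, dstBits = 8, and
// bitwidthOffset = 4: maskRightAligned = 0x0F, writeMaskInverse = 0xF0, and
// the returned mask is 0xF0 ^ 0xFF = 0x0F, which keeps the untouched nibble
// and zeroes the nibble being stored to when applied with arith.andi.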
+
+ /// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+ /// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+ /// the returned index has the granularity of `dstBits`.
+ static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+                                       OpFoldResult linearizedIndex,
+                                       int64_t srcBits, int64_t dstBits) {
+   AffineExpr s0;
+   bindSymbols(builder.getContext(), s0);
+   int64_t scaler = dstBits / srcBits;
+   OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+       builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+   return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+ }
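// Illustrative scaling, assuming srcBits = 4 and dstBits = 8 (scaler = 2):
// a linearized i4 index of 5 maps to i8 index 5 floordiv 2 = 2, i.e. the
// sixth 4-bit element lives in the third byte of the emulated buffer.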
+
+ static OpFoldResult
+ getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+                         const SmallVector<OpFoldResult> &indices,
+                         Value memref) {
+   auto stridedMetadata =
+       builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+   OpFoldResult linearizedIndices;
+   std::tie(std::ignore, linearizedIndices) =
+       memref::getLinearizedMemRefOffsetAndSize(
+           builder, loc, srcBits, srcBits,
+           stridedMetadata.getConstifiedMixedOffset(),
+           stridedMetadata.getConstifiedMixedSizes(),
+           stridedMetadata.getConstifiedMixedStrides(), indices);
+   return linearizedIndices;
+ }
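// Illustration: for a contiguous memref<3x4xi4> (strides [4, 1], offset 0),
// this helper linearizes indices (i, j) to i * 4 + j, still expressed in
// 4-bit element granularity; the scaling down to byte indices happens
// separately in getIndicesForLoadOrStore.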
+
namespace {

//===----------------------------------------------------------------------===//
// ConvertMemRefAlloc
//===----------------------------------------------------------------------===//

- struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
-   using OpConversionPattern::OpConversionPattern;
+ template <typename OpTy>
+ struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> {
+   using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
-   matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+   matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
-     auto currentType = op.getMemref().getType().cast<MemRefType>();
-     auto newResultType =
-         getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+     static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+                       std::is_same<OpTy, memref::AllocaOp>(),
+                   "expected only memref::AllocOp or memref::AllocaOp");
+     auto currentType = cast<MemRefType>(op.getMemref().getType());
+     auto newResultType = dyn_cast<MemRefType>(
+         this->getTypeConverter()->convertType(op.getType()));
    if (!newResultType) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
@@ -73,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {

    // Special case zero-rank memrefs.
    if (currentType.getRank() == 0) {
-       rewriter.replaceOpWithNewOp<memref::AllocOp>(
-           op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
-           adaptor.getAlignmentAttr());
+       rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+                                         adaptor.getSymbolOperands(),
+                                         adaptor.getAlignmentAttr());
      return success();
    }

@@ -97,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
          rewriter, loc, linearizedMemRefInfo.linearizedSize));
    }

-     rewriter.replaceOpWithNewOp<memref::AllocOp>(
-         op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
-         adaptor.getAlignmentAttr());
+     rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+                                       adaptor.getSymbolOperands(),
+                                       adaptor.getAlignmentAttr());
    return success();
  }
};
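// Illustration of the intended effect (assuming an 8-bit emulation target):
// both `memref.alloc() : memref<6xi4>` and `memref.alloca() : memref<6xi4>`
// are rewritten to allocate a linearized `memref<3xi8>`.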
@@ -155,32 +218,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                 ValueRange{});
    } else {
-       SmallVector<OpFoldResult> indices =
-           getAsOpFoldResult(adaptor.getIndices());
-
-       auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-           loc, op.getMemRef());
-
      // Linearize the indices of the original load instruction. Do not account
      // for the scaling yet. This will be accounted for later.
-       OpFoldResult linearizedIndices;
-       std::tie(std::ignore, linearizedIndices) =
-           memref::getLinearizedMemRefOffsetAndSize(
-               rewriter, loc, srcBits, srcBits,
-               stridedMetadata.getConstifiedMixedOffset(),
-               stridedMetadata.getConstifiedMixedSizes(),
-               stridedMetadata.getConstifiedMixedStrides(), indices);
-
-       AffineExpr s0;
-       bindSymbols(rewriter.getContext(), s0);
-       int64_t scaler = dstBits / srcBits;
-       OpFoldResult scaledLinearizedIndices =
-           affine::makeComposedFoldedAffineApply(
-               rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+       OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+           rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
      Value newLoad = rewriter.create<memref::LoadOp>(
          loc, adaptor.getMemref(),
-           getValueOrCreateConstantIndexOp(rewriter, loc,
-                                           scaledLinearizedIndices));
+           getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+                                    dstBits));

      // Get the offset and shift the bits to the rightmost.
      // Note, currently only the big-endian is supported.
@@ -211,6 +257,150 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
  }
};
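// Summary of the emulated load above (illustration, assuming srcBits = 4 and
// dstBits = 8): the i4 indices are linearized, scaled down to a byte index,
// the containing i8 is loaded, and the loaded byte is shifted right by the
// bit offset so the requested 4-bit value ends up in the low bits.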

+ //===----------------------------------------------------------------------===//
+ // ConvertMemRefReinterpretCast
+ //===----------------------------------------------------------------------===//
+
+ /// Currently there is very limited support for memref::ReinterpretCastOp
+ /// conversion. Only the 0 dimensional case is supported.
+ struct ConvertMemRefReinterpretCast final
+     : OpConversionPattern<memref::ReinterpretCastOp> {
+   using OpConversionPattern::OpConversionPattern;
+
+   LogicalResult
+   matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                   ConversionPatternRewriter &rewriter) const override {
+     MemRefType newTy =
+         dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+     if (!newTy) {
+       return rewriter.notifyMatchFailure(
+           op->getLoc(),
+           llvm::formatv("failed to convert memref type: {0}", op.getType()));
+     }
+
+     auto convertedElementType = newTy.getElementType();
+     auto oldElementType = op.getType().getElementType();
+     int srcBits = oldElementType.getIntOrFloatBitWidth();
+     int dstBits = convertedElementType.getIntOrFloatBitWidth();
+     if (dstBits % srcBits != 0) {
+       return rewriter.notifyMatchFailure(
+           op, "only dstBits % srcBits == 0 supported");
+     }
+
+     // Only support offset for 0-D subview.
+     if (op.getType().getRank() != 0) {
+       return rewriter.notifyMatchFailure(
+           op->getLoc(), "subview with rank > 0 is not supported");
+     }
+
+     int64_t offset = op.getStaticOffset(0);
+     // Only support static sizes and offsets.
+     if (offset == ShapedType::kDynamic) {
+       return rewriter.notifyMatchFailure(
+           op->getLoc(), "subview with dynamic offset is not supported");
+     }
+
+     int elementsPerByte = dstBits / srcBits;
+     if (offset % elementsPerByte != 0) {
+       return rewriter.notifyMatchFailure(
+           op->getLoc(),
+           "subview with offset not multiple of elementsPerByte is not "
+           "supported");
+     }
+
+     offset = offset / elementsPerByte;
+
+     rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+         op, newTy, *adaptor.getODSOperands(0).begin(), offset,
+         SmallVector<int64_t>{}, op.getStaticStrides());
+     return success();
+   }
+ };
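// Illustrative offset rewrite, assuming i4 elements emulated in i8
// (elementsPerByte = 2): a 0-D reinterpret_cast with static offset 6 (in i4
// elements) becomes one with offset 6 / 2 = 3 (in i8 elements); an odd
// offset such as 5 is rejected by the pattern above.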
+
+ //===----------------------------------------------------------------------===//
+ // ConvertMemrefStore
+ //===----------------------------------------------------------------------===//
+
+ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+   using OpConversionPattern::OpConversionPattern;
+
+   LogicalResult
+   matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                   ConversionPatternRewriter &rewriter) const override {
+     auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+     auto convertedElementType = convertedType.getElementType();
+     auto oldElementType = op.getMemRefType().getElementType();
+     int srcBits = oldElementType.getIntOrFloatBitWidth();
+     int dstBits = convertedElementType.getIntOrFloatBitWidth();
+     auto dstIntegerType = rewriter.getIntegerType(dstBits);
+     if (dstBits % srcBits != 0) {
+       return rewriter.notifyMatchFailure(
+           op, "only dstBits % srcBits == 0 supported");
+     }
+
+     Location loc = op.getLoc();
+     Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+                                                           adaptor.getValue());
+
+     // Special case 0-rank memref stores. We can compute the mask at compile
+     // time.
+     if (convertedType.getRank() == 0) {
+       // Shift extended value to be left aligned
+       auto shiftValAttr =
+           rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
+       Value shiftVal =
+           rewriter.create<arith::ConstantOp>(loc, dstIntegerType, shiftValAttr)
+               .getResult();
+       Value alignedVal =
+           rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal)
+               .getResult();
+       // Create mask to clear destination bits
+       auto writeMaskValAttr = rewriter.getIntegerAttr(
+           dstIntegerType, (1 << (dstBits - srcBits)) - 1);
+       Value writeMask =
+           rewriter
+               .create<arith::ConstantOp>(loc, dstIntegerType, writeMaskValAttr)
+               .getResult();
+
+       // Clear destination bits
+       rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                            writeMask, adaptor.getMemref(),
+                                            ValueRange{});
+       // Write src bits to destination
+       rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                            alignedVal, adaptor.getMemref(),
+                                            ValueRange{});
+       rewriter.eraseOp(op);
+       return success();
+     }
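// Illustrative constants for the 0-rank case, assuming srcBits = 4 and
// dstBits = 8: shiftVal = 4, so the value lands in the upper nibble, and
// writeMask = (1 << 4) - 1 = 0x0F, so the atomic `andi` clears the upper
// nibble before the atomic `ori` merges in the shifted value.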
+
+     OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+         rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+     Value storeIndices = getIndicesForLoadOrStore(
+         rewriter, loc, linearizedIndices, srcBits, dstBits);
+     Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+                                                 dstBits, rewriter, true);
+     Value writeMask = getAtomicWriteMask(loc, linearizedIndices, srcBits,
+                                          dstBits, bitwidthOffset, rewriter);
+     // Align the value to write with the destination bits
+     Value alignedVal =
+         rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset)
+             .getResult();
+
+     // Clear destination bits
+     rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                          writeMask, adaptor.getMemref(),
+                                          storeIndices);
+     // Write src bits to destination
+     rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                          alignedVal, adaptor.getMemref(),
+                                          storeIndices);
+
+     rewriter.eraseOp(op);
+     return success();
+   }
+ };
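// Summary of the general store path above (illustration, assuming i4 stored
// into emulated i8): the element index is linearized and scaled to a byte
// index, a clearing mask and shift amount are derived from the right-aligned
// bit offset, and two atomic read-modify-writes first clear the destination
// bits (andi) and then merge in the shifted value (ori).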
+
//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
@@ -291,8 +481,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
    RewritePatternSet &patterns) {

  // Populate `memref.*` conversion patterns.
-   patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-                ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+   patterns.add<ConvertMemRefAlloc<memref::AllocOp>,
+                ConvertMemRefAlloc<memref::AllocaOp>, ConvertMemRefLoad,
+                ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
+                ConvertMemrefStore, ConvertMemRefReinterpretCast>(
      typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
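For context, a minimal sketch of how these emulation patterns might be driven from a conversion pass; the function name, the 8-bit target width, and the elided legality setup below are illustrative assumptions, not part of this patch.

// Hypothetical driver sketch (illustration only).
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static LogicalResult emulateNarrowMemRefTypes(Operation *root) {
  MLIRContext *ctx = root->getContext();
  // Emulate sub-byte element types on top of 8-bit storage.
  arith::NarrowTypeEmulationConverter converter(/*targetBitwidth=*/8);
  RewritePatternSet patterns(ctx);
  memref::populateMemRefNarrowTypeEmulationPatterns(converter, patterns);
  ConversionTarget target(*ctx);
  // Legality setup and the memref type-conversion hookup are elided here.
  return applyPartialConversion(root, target, std::move(patterns));
}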