
Commit 1b4a88b
[mlir] Add narrow type emulation conversions
Adds narrow type emulation support for:

- `memref.alloca`
- `memref.store`
- `memref.reinterpret_cast`
1 parent f2bf44b commit 1b4a88b
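
For context, the kind of input IR these patterns target looks roughly like the following (an illustrative sketch, not taken from the commit's tests; it assumes an emulated load bitwidth of 8):

```mlir
// Illustrative input: a sub-byte (i4) allocation and store. After emulation,
// both operate on an i8 buffer instead.
func.func @store_i4(%arg0: i4) {
  %0 = memref.alloca() : memref<i4>
  memref.store %arg0, %0[] : memref<i4>
  return
}
```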

File tree

2 files changed: +460 −40 lines

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 232 additions & 40 deletions
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -35,36 +36,98 @@ using namespace mlir;
 /// Return the bit offset of the value at position `srcIdx`. For example, if
 /// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
 /// located at (x % 2) * 4. Because there are two elements in one i8, and one
-/// element has 4 bits.
+/// element has 4 bits. If `rightOffset` is true, return the offset from the
+/// right side of the `dstBits` container instead of the left side.
 static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                   int sourceBits, int targetBits,
-                                  OpBuilder &builder) {
+                                  OpBuilder &builder,
+                                  bool rightOffset = false) {
   assert(targetBits % sourceBits == 0);
   AffineExpr s0;
   bindSymbols(builder.getContext(), s0);
   int scaleFactor = targetBits / sourceBits;
-  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
-      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  AffineExpr offsetExpr =
+      rightOffset ? (scaleFactor - 1 - s0 % scaleFactor) * sourceBits
+                  : (s0 % scaleFactor) * sourceBits;
+  OpFoldResult offsetVal =
+      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
   Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
   IntegerType dstType = builder.getIntegerType(targetBits);
   return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }

+/// When writing a subbyte size, writing needs to happen atomically in case of
+/// another write happening on the same byte at the same time. To do the write,
+/// we first must clear `dstBits` at the `linearizedIndices` of the subbyte
+/// store. This function returns the appropriate mask for clearing these bits.
+static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices,
+                                int64_t srcBits, int64_t dstBits,
+                                Value bitwidthOffset, OpBuilder &builder) {
+  auto dstIntegerType = builder.getIntegerType(dstBits);
+  auto maskRightAlignedAttr =
+      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+  Value maskRightAligned =
+      builder
+          .create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
+          .getResult();
+  Value writeMaskInverse =
+      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+  Value flipVal =
+      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+          .getResult();
+  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`.
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+                                      OpFoldResult linearizedIndex,
+                                      int64_t srcBits, int64_t dstBits) {
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int64_t scaler = dstBits / srcBits;
+  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+                        const SmallVector<OpFoldResult> &indices,
+                        Value memref) {
+  auto stridedMetadata =
+      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+  OpFoldResult linearizedIndices;
+  std::tie(std::ignore, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          builder, loc, srcBits, srcBits,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(), indices);
+  return linearizedIndices;
+}
+
 namespace {

 //===----------------------------------------------------------------------===//
 // ConvertMemRefAlloc
 //===----------------------------------------------------------------------===//

-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
-  using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;

   LogicalResult
-  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto currentType = op.getMemref().getType().cast<MemRefType>();
-    auto newResultType =
-        getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+                      std::is_same<OpTy, memref::AllocaOp>(),
+                  "expected only memref::AllocOp or memref::AllocaOp");
+    auto currentType = cast<MemRefType>(op.getMemref().getType());
+    auto newResultType = dyn_cast<MemRefType>(
+        this->getTypeConverter()->convertType(op.getType()));
     if (!newResultType) {
       return rewriter.notifyMatchFailure(
           op->getLoc(),
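
To make the helper arithmetic concrete: with srcBits = 4 and dstBits = 8, scaleFactor is 2, so element index 5 lives in byte 5 floordiv 2 = 2. Its left offset is (5 mod 2) * 4 = 4 and its right-aligned offset is (2 - 1 - 5 mod 2) * 4 = 0; getAtomicWriteMask then builds (0x0F << 0) xor 0xFF = 0xF0, which clears that element's nibble while preserving the rest of the byte.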
@@ -73,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {

     // Special case zero-rank memrefs.
     if (currentType.getRank() == 0) {
-      rewriter.replaceOpWithNewOp<memref::AllocOp>(
-          op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
-          adaptor.getAlignmentAttr());
+      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+                                        adaptor.getSymbolOperands(),
+                                        adaptor.getAlignmentAttr());
       return success();
     }

@@ -97,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
           rewriter, loc, linearizedMemRefInfo.linearizedSize));
     }

-    rewriter.replaceOpWithNewOp<memref::AllocOp>(
-        op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
-        adaptor.getAlignmentAttr());
+    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+                                      adaptor.getSymbolOperands(),
+                                      adaptor.getAlignmentAttr());
     return success();
   }
 };
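
The effect of the now-templated alloc/alloca conversion, sketched with illustrative types (assuming an i8 load bitwidth): the sub-byte buffer is linearized into bytes.

```mlir
// Illustrative: two i4 values pack into each i8.
%a = memref.alloca() : memref<6xi4>   // before conversion
%b = memref.alloca() : memref<3xi8>   // after conversion
```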
@@ -155,32 +218,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
       bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                  ValueRange{});
     } else {
-      SmallVector<OpFoldResult> indices =
-          getAsOpFoldResult(adaptor.getIndices());
-
-      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-          loc, op.getMemRef());
-
       // Linearize the indices of the original load instruction. Do not account
       // for the scaling yet. This will be accounted for later.
-      OpFoldResult linearizedIndices;
-      std::tie(std::ignore, linearizedIndices) =
-          memref::getLinearizedMemRefOffsetAndSize(
-              rewriter, loc, srcBits, srcBits,
-              stridedMetadata.getConstifiedMixedOffset(),
-              stridedMetadata.getConstifiedMixedSizes(),
-              stridedMetadata.getConstifiedMixedStrides(), indices);
-
-      AffineExpr s0;
-      bindSymbols(rewriter.getContext(), s0);
-      int64_t scaler = dstBits / srcBits;
-      OpFoldResult scaledLinearizedIndices =
-          affine::makeComposedFoldedAffineApply(
-              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
       Value newLoad = rewriter.create<memref::LoadOp>(
           loc, adaptor.getMemref(),
-          getValueOrCreateConstantIndexOp(rewriter, loc,
-                                          scaledLinearizedIndices));
+          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+                                   dstBits));

       // Get the offset and shift the bits to the rightmost.
       // Note, currently only the big-endian is supported.
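
For comparison with the new store pattern below, the existing load emulation (refactored here onto the shared helpers) loads the containing byte, shifts the element to the rightmost bits, and truncates. Roughly, for i4 elements in i8 storage (SSA names illustrative):

```mlir
%byte    = memref.load %mem[%scaledIdx] : memref<3xi8>
%shifted = arith.shrui %byte, %offset : i8
%elem    = arith.trunci %shifted : i8 to i4
```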
@@ -211,6 +257,150 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };

+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+/// Currently there is very limited support for memref::ReinterpretCastOp
+/// conversion. Only the 0 dimensional case is supported.
+struct ConvertMemRefReinterpretCast final
+    : OpConversionPattern<memref::ReinterpretCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType newTy =
+        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    auto convertedElementType = newTy.getElementType();
+    auto oldElementType = op.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    // Only support offset for 0-D subview.
+    if (op.getType().getRank() != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "subview with rank > 0 is not supported");
+    }
+
+    int64_t offset = op.getStaticOffset(0);
+    // Only support static sizes and offsets.
+    if (offset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "subview with dynamic offset is not supported");
+    }
+
+    int elementsPerByte = dstBits / srcBits;
+    if (offset % elementsPerByte != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          "subview with offset not multiple of elementsPerByte is not "
+          "supported");
+    }
+
+    offset = offset / elementsPerByte;
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, newTy, *adaptor.getODSOperands(0).begin(), offset,
+        SmallVector<int64_t>{}, op.getStaticStrides());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    auto convertedElementType = convertedType.getElementType();
+    auto oldElementType = op.getMemRefType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    Location loc = op.getLoc();
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+                                                          adaptor.getValue());
+
+    // Special case 0-rank memref stores. We can compute the mask at compile
+    // time.
+    if (convertedType.getRank() == 0) {
+      // Shift extended value to be left aligned.
+      auto shiftValAttr =
+          rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
+      Value shiftVal =
+          rewriter.create<arith::ConstantOp>(loc, dstIntegerType, shiftValAttr)
+              .getResult();
+      Value alignedVal =
+          rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal)
+              .getResult();
+      // Create mask to clear destination bits.
+      auto writeMaskValAttr = rewriter.getIntegerAttr(
+          dstIntegerType, (1 << (dstBits - srcBits)) - 1);
+      Value writeMask =
+          rewriter
+              .create<arith::ConstantOp>(loc, dstIntegerType, writeMaskValAttr)
+              .getResult();
+
+      // Clear destination bits.
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                           writeMask, adaptor.getMemref(),
+                                           ValueRange{});
+      // Write src bits to destination.
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                           alignedVal, adaptor.getMemref(),
+                                           ValueRange{});
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+    Value storeIndices = getIndicesForLoadOrStore(
+        rewriter, loc, linearizedIndices, srcBits, dstBits);
+    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+                                                dstBits, rewriter, true);
+    Value writeMask = getAtomicWriteMask(loc, linearizedIndices, srcBits,
+                                         dstBits, bitwidthOffset, rewriter);
+    // Align the value to write with the destination bits.
+    Value alignedVal =
+        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset)
+            .getResult();
+
+    // Clear destination bits.
+    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                         writeMask, adaptor.getMemref(),
+                                         storeIndices);
+    // Write src bits to destination.
+    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                         alignedVal, adaptor.getMemref(),
+                                         storeIndices);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
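
Putting the store pieces together, the 0-rank path lowers a sub-byte store to a read-modify-write pair of atomics. A rough before/after for i4 in i8 storage (a sketch derived from the pattern above; SSA names and constants are illustrative, not copied from the commit's tests):

```mlir
// Before: 0-D sub-byte store.
memref.store %arg0, %alloc[] : memref<i4>

// After (roughly): zero-extend, left-align, then atomically clear and set.
%ext     = arith.extui %arg0 : i4 to i8
%c4      = arith.constant 4 : i8    // dstBits - srcBits
%aligned = arith.shli %ext, %c4 : i8
%mask    = arith.constant 15 : i8   // keeps the low nibble, clears the high
%0 = memref.atomic_rmw andi %mask, %alloc[] : (i8, memref<i8>) -> i8
%1 = memref.atomic_rmw ori %aligned, %alloc[] : (i8, memref<i8>) -> i8
```

The ranked path has the same shape, except the shift amount and clear mask are computed at runtime from the linearized index.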
@@ -291,8 +481,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {

   // Populate `memref.*` conversion patterns.
-  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+  patterns.add<ConvertMemRefAlloc<memref::AllocOp>,
+               ConvertMemRefAlloc<memref::AllocaOp>, ConvertMemRefLoad,
+               ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
+               ConvertMemrefStore, ConvertMemRefReinterpretCast>(
       typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
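
The reinterpret_cast support registered above is deliberately narrow: rank-0 results only, with a static offset divisible by elementsPerByte. Sketched with illustrative types (i4 emulated in i8, so elementsPerByte = 2; the exact converted layouts depend on the type converter):

```mlir
// Before: a rank-0 view at i4 element offset 2.
%v = memref.reinterpret_cast %base to offset: [2], sizes: [], strides: []
    : memref<?xi4> to memref<i4, strided<[], offset: 2>>

// After (roughly): the same view over the emulated byte buffer, with the
// offset divided by elementsPerByte.
%v = memref.reinterpret_cast %base to offset: [1], sizes: [], strides: []
    : memref<?xi8> to memref<i8, strided<[], offset: 1>>
```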
