
Commit 0f8bab8

Author: Mahesh Ravishankar (committed)
[mlir] Revamp implementation of sub-byte load/store emulation.
When handling sub-byte emulation, the sizes of the converted `memref`s also need to be updated (this was not done by the existing implementation), which adds the complexity of having to linearize the `memref`s as well. Consider a `memref<3x3xi4>` where the `i4` elements are packed: it has an overall size of 5 bytes (36 bits rounded up to a whole number of bytes) and can only be represented by a `memref<5xi8>`. A `memref<3x2xi8>` would instead imply an implicit padding of 4 bits at the end of each row. So linearization is incorporated into the sub-byte load/store emulation.

This patch also updates some of the utility functions to make better use of statically available information using `OpFoldResult` and `makeComposedFoldedAffineApplyOps`.

Reviewed By: hanchung, yzhang93

Differential Revision: https://reviews.llvm.org/D158125
1 parent 6869786 commit 0f8bab8
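
To make the packing arithmetic from the commit message concrete: 3 x 3 i4 elements occupy 36 bits, which rounds up to 5 bytes. Below is a minimal standalone C++ sketch of that computation, as an illustration only; `packedByteSize` is a hypothetical helper and not part of this patch.

#include <cassert>
#include <cstdint>

// Bytes needed to store `numElements` packed elements of `elementBits` bits
// each, rounded up to a whole number of bytes. Hypothetical helper for
// illustration; not part of the patch.
int64_t packedByteSize(int64_t numElements, int64_t elementBits) {
  int64_t totalBits = numElements * elementBits;
  return (totalBits + 7) / 8; // ceilDiv(totalBits, 8)
}

int main() {
  // memref<3x3xi4>: 9 i4 values = 36 bits = 4.5 bytes, stored as memref<5xi8>.
  assert(packedByteSize(3 * 3, 4) == 5);
  // A memref<3x2xi8> layout would use 6 bytes and implicitly pad each row of
  // three i4 values by 4 bits, which is what the linearization avoids.
  return 0;
}

This is the same ceiling division the patch applies when it linearizes and rescales the converted memref sizes.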

File tree

10 files changed: +465, -377 lines


mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h

Lines changed: 31 additions & 30 deletions
@@ -28,36 +28,37 @@ namespace memref {
 /// contiguous chunk of memory.
 bool isStaticShapeAndContiguousRowMajor(MemRefType type);
 
-/// Returns the flattened 1-D memref and linearized offset for narrow type
-/// emulation.
-///
-/// The emulation only works on 1D memref types. To make this work on N-D
-/// memref, we need to linearize the offset.
-///
-/// For example, to emulate i4 to i8, the following op:
-///
-///   %0 = memref.load %arg0[%v0, %v1] :
-///                     memref<?x?xi4, strided<[?, ?], offset: ?>>
-///
-/// can be replaced with
-///
-///   %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
-///
-///   %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
-///   %linearized_size = %size0 * %size1
-///   %scaled_linear_offset = %linearized_offset / 8 * 4
-///   %scaled_base_offset = %offset / 8 * 4
-///
-///   %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
-///                     sizes = [%linearized_size], strides = [%stride#1]
-///
-///   %new_load = memref.load %linearized[%scaled_linear_offset] :
-///                     memref<?xi8, strided<[?], offset: ?>>
-std::pair<Value, Value>
-getLinearizeMemRefAndOffset(Location loc, MemRefType sourceType, int srcBits,
-                            int dstBits, SmallVector<Value> indices,
-                            memref::ExtractStridedMetadataOp stridedMetadata,
-                            OpBuilder &builder);
+/// For a `memref` with `offset`, `sizes` and `strides`, returns the
+/// offset and size to use for the linearized `memref`.
+/// - If the linearization is done for emulating load/stores of
+///   element type with bitwidth `srcBits` using element type with
+///   bitwidth `dstBits`, the linearized offset and size are
+///   scaled down by `dstBits`/`srcBits`.
+/// - If `indices` is provided, it represents the position in the
+///   original `memref` being accessed. The method then returns the
+///   index to use in the linearized `memref`. The linearized index
+///   is also scaled down by `dstBits`/`srcBits`. If `indices` is not
+///   provided, 0 is returned for the linearized index.
+struct LinearizedMemRefInfo {
+  OpFoldResult linearizedOffset;
+  OpFoldResult linearizedSize;
+};
+std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
+    OpBuilder &builder, Location loc, int srcBits, int dstBits,
+    OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
+    ArrayRef<OpFoldResult> strides, ArrayRef<OpFoldResult> indices = {});
+
+/// For a `memref` with `offset` and `sizes`, returns the
+/// offset and size to use for the linearized `memref`, assuming that
+/// the strides are computed from a row-major ordering of the sizes.
+/// - If the linearization is done for emulating load/stores of
+///   element type with bitwidth `srcBits` using element type with
+///   bitwidth `dstBits`, the linearized offset and size are
+///   scaled down by `dstBits`/`srcBits`.
+LinearizedMemRefInfo
+getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
+                                 int dstBits, OpFoldResult offset,
+                                 ArrayRef<OpFoldResult> sizes);
 
 } // namespace memref
 } // namespace mlir
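
For intuition, the arithmetic behind the new `getLinearizedMemRefOffsetAndSize` helper can be sketched in plain integers as below. This is an illustrative analogue of the documented behavior (row-major linearization, then scaling by `dstBits`/`srcBits`), not the MLIR implementation itself; the names `LinearizedInfo` and `linearize` are made up for the example, and the round-up of the size mirrors the type converter's `getLinearizedShape` added later in this commit.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Plain-integer analogue of LinearizedMemRefInfo, for illustration only.
struct LinearizedInfo {
  int64_t linearizedOffset; // offset into the emulated (wider) memref
  int64_t linearizedSize;   // number of wider elements needed
};

// Linearize (offset, sizes, strides, indices) and scale by dstBits/srcBits,
// mirroring the scaling described in the header comment. Illustrative sketch,
// assuming row-major strides and dstBits % srcBits == 0.
LinearizedInfo linearize(int64_t offset, const std::vector<int64_t> &sizes,
                         const std::vector<int64_t> &strides,
                         const std::vector<int64_t> &indices, int srcBits,
                         int dstBits, int64_t *linearizedIndex) {
  assert(dstBits % srcBits == 0);
  int64_t scale = dstBits / srcBits;

  int64_t index = 0, size = 1;
  for (std::size_t d = 0; d < sizes.size(); ++d) {
    index += indices[d] * strides[d];
    size *= sizes[d];
  }
  if (linearizedIndex)
    *linearizedIndex = index / scale; // scaled-down index in the i8 memref
  // Offset and size are scaled down as well; the size is rounded up so the
  // last partially filled wide element is still allocated.
  return {offset / scale, (size + scale - 1) / scale};
}

int main() {
  // memref<3x3xi4> accessed at [2, 1], emulated with i8 (scale = 2):
  int64_t idx = 0;
  LinearizedInfo info =
      linearize(/*offset=*/0, {3, 3}, {3, 1}, {2, 1}, 4, 8, &idx);
  assert(info.linearizedSize == 5); // 9 i4 values fit in 5 bytes
  assert(idx == 3);                 // linear element 7 lives in byte 3
  return 0;
}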

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 121 additions & 64 deletions
@@ -35,18 +35,18 @@ using namespace mlir;
 /// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
 /// located at (x % 2) * 4. Because there are two elements in one i8, and one
 /// element has 4 bits.
-static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
-                                  int targetBits, OpBuilder &builder) {
+static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
+                                  int sourceBits, int targetBits,
+                                  OpBuilder &builder) {
   assert(targetBits % sourceBits == 0);
-  IntegerType targetType = builder.getIntegerType(targetBits);
-  IntegerAttr idxAttr =
-      builder.getIntegerAttr(targetType, targetBits / sourceBits);
-  auto idx = builder.create<arith::ConstantOp>(loc, targetType, idxAttr);
-  IntegerAttr srcBitsAttr = builder.getIntegerAttr(targetType, sourceBits);
-  auto srcBitsValue =
-      builder.create<arith::ConstantOp>(loc, targetType, srcBitsAttr);
-  auto m = builder.create<arith::RemUIOp>(loc, srcIdx, idx);
-  return builder.create<arith::MulIOp>(loc, targetType, m, srcBitsValue);
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int scaleFactor = targetBits / sourceBits;
+  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
+      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
+  IntegerType dstType = builder.getIntegerType(targetBits);
+  return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }
 
 namespace {
@@ -61,15 +61,43 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
   LogicalResult
   matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type newTy = getTypeConverter()->convertType(op.getType());
-    if (!newTy) {
+    auto currentType = op.getMemref().getType().cast<MemRefType>();
+    auto newResultType =
+        getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+    if (!newResultType) {
       return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
     }
 
+    // Special case zero-rank memrefs.
+    if (currentType.getRank() == 0) {
+      rewriter.replaceOpWithNewOp<memref::AllocOp>(
+          op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
+          adaptor.getAlignmentAttr());
+      return success();
+    }
+
+    Location loc = op.getLoc();
+    OpFoldResult zero = rewriter.getIndexAttr(0);
+    SmallVector<OpFoldResult> indices(currentType.getRank(), zero);
+
+    // Get linearized type.
+    int srcBits = currentType.getElementType().getIntOrFloatBitWidth();
+    int dstBits = newResultType.getElementType().getIntOrFloatBitWidth();
+    SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+
+    memref::LinearizedMemRefInfo linearizedMemRefInfo =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits, /*offset =*/zero, sizes);
+    SmallVector<Value> dynamicLinearizedSize;
+    if (!newResultType.hasStaticShape()) {
+      dynamicLinearizedSize.push_back(getValueOrCreateConstantIndexOp(
+          rewriter, loc, linearizedMemRefInfo.linearizedSize));
+    }
+
     rewriter.replaceOpWithNewOp<memref::AllocOp>(
-        op, newTy, adaptor.getDynamicSizes(), adaptor.getSymbolOperands(),
+        op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
         adaptor.getAlignmentAttr());
     return success();
   }
@@ -109,73 +137,68 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   LogicalResult
   matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type newTy = getTypeConverter()->convertType(op.getMemRefType());
-    if (!newTy) {
-      return rewriter.notifyMatchFailure(
-          op->getLoc(), llvm::formatv("failed to convert memref type: {0}",
-                                      op.getMemRefType()));
-    }
-
-    if (op.getMemRefType() == newTy)
-      return failure();
-
-    auto loc = op.getLoc();
-    auto sourceType = cast<MemRefType>(adaptor.getMemref().getType());
-    unsigned sourceRank = sourceType.getRank();
-    SmallVector<Value> indices = adaptor.getIndices();
-    assert(indices.size() == sourceRank);
-
-    auto srcElementType = sourceType.getElementType();
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    auto convertedElementType = convertedType.getElementType();
     auto oldElementType = op.getMemRefType().getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = srcElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
     if (dstBits % srcBits != 0) {
      return rewriter.notifyMatchFailure(
          op, "only dstBits % srcBits == 0 supported");
     }
 
-    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-        loc, adaptor.getMemref());
-
-    Value newLoad, lastIdx;
-    if (sourceRank == 0) {
-      newLoad = rewriter.create<memref::LoadOp>(
-          loc, srcElementType, adaptor.getMemref(), adaptor.getIndices());
-
-      lastIdx = stridedMetadata.getOffset();
+    Location loc = op.getLoc();
+    // Special case 0-rank memref loads.
+    Value bitsLoad;
+    if (convertedType.getRank() == 0) {
+      bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
+                                                 ValueRange{});
    } else {
-      auto [reinterpret, linearizedOffset] =
-          memref::getLinearizeMemRefAndOffset(loc, sourceType, srcBits, dstBits,
-                                              adaptor.getIndices(),
-                                              stridedMetadata, rewriter);
-
-      newLoad = rewriter.create<memref::LoadOp>(loc, srcElementType,
-                                                reinterpret, linearizedOffset);
-
-      lastIdx = adaptor.getIndices().back();
+      SmallVector<OpFoldResult> indices =
+          getAsOpFoldResult(adaptor.getIndices());
+
+      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+          loc, op.getMemRef());
+
+      // Linearize the indices of the original load instruction. Do not account
+      // for the scaling yet. This will be accounted for later.
+      OpFoldResult linearizedIndices;
+      std::tie(std::ignore, linearizedIndices) =
+          memref::getLinearizedMemRefOffsetAndSize(
+              rewriter, loc, srcBits, srcBits,
+              stridedMetadata.getConstifiedMixedOffset(),
+              stridedMetadata.getConstifiedMixedSizes(),
+              stridedMetadata.getConstifiedMixedStrides(), indices);
+
+      AffineExpr s0;
+      bindSymbols(rewriter.getContext(), s0);
+      int64_t scaler = dstBits / srcBits;
+      OpFoldResult scaledLinearizedIndices =
+          affine::makeComposedFoldedAffineApply(
+              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      Value newLoad = rewriter.create<memref::LoadOp>(
+          loc, adaptor.getMemref(),
+          getValueOrCreateConstantIndexOp(rewriter, loc,
+                                          scaledLinearizedIndices));
+
+      // Get the offset and shift the bits to the rightmost.
+      // Note, currently only the big-endian is supported.
+      Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices,
+                                                  srcBits, dstBits, rewriter);
+      bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset);
    }
 
-    // Get the offset and shift the bits to the rightmost.
-    // Note, currently only the big-endian is supported.
-    auto castLastIdx =
-        rewriter.create<arith::IndexCastUIOp>(loc, srcElementType, lastIdx);
-
-    Value BitwidthOffset =
-        getOffsetForBitwidth(loc, castLastIdx, srcBits, dstBits, rewriter);
-    auto bitsLoad =
-        rewriter.create<arith::ShRSIOp>(loc, newLoad, BitwidthOffset);
-
    // Get the corresponding bits. If the arith computation bitwidth equals
    // to the emulated bitwidth, we apply a mask to extract the low bits.
    // It is not clear if this case actually happens in practice, but we keep
    // the operations just in case. Otherwise, if the arith computation bitwidth
    // is different from the emulated bitwidth we truncate the result.
    Operation *result;
    auto resultTy = getTypeConverter()->convertType(oldElementType);
-    if (resultTy == srcElementType) {
+    if (resultTy == convertedElementType) {
      auto mask = rewriter.create<arith::ConstantOp>(
-          loc, srcElementType,
-          rewriter.getIntegerAttr(srcElementType, (1 << srcBits) - 1));
+          loc, convertedElementType,
+          rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1));
 
      result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask);
    } else {
@@ -200,6 +223,25 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
   patterns
       .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment>(
           typeConverter, patterns.getContext());
+  memref::populateResolveExtractStridedMetadataPatterns(patterns);
+}
+
+static SmallVector<int64_t> getLinearizedShape(MemRefType ty, int srcBits,
+                                               int dstBits) {
+  if (ty.getRank() == 0)
+    return {};
+
+  int64_t linearizedShape = 1;
+  for (auto shape : ty.getShape()) {
+    if (shape == ShapedType::kDynamic)
+      return {ShapedType::kDynamic};
+    linearizedShape *= shape;
+  }
+  int scale = dstBits / srcBits;
+  // Scale the size to the ceilDiv(linearizedShape, scale)
+  // to accommodate all the values.
+  linearizedShape = (linearizedShape + scale - 1) / scale;
+  return {linearizedShape};
 }
 
 void memref::populateMemRefNarrowTypeEmulationConversions(
@@ -215,11 +257,26 @@ void memref::populateMemRefNarrowTypeEmulationConversions(
     if (width >= loadStoreWidth)
       return ty;
 
+    // Currently only handle the case where the innermost stride is 1; check
+    // that here.
+    SmallVector<int64_t> strides;
+    int64_t offset;
+    if (failed(getStridesAndOffset(ty, strides, offset)))
+      return std::nullopt;
+    if (!strides.empty() && strides.back() != 1)
+      return std::nullopt;
+
     auto newElemTy = IntegerType::get(ty.getContext(), loadStoreWidth,
                                       intTy.getSignedness());
     if (!newElemTy)
       return std::nullopt;
 
-    return ty.cloneWith(std::nullopt, newElemTy);
+    StridedLayoutAttr layoutAttr;
+    if (offset != 0) {
+      layoutAttr = StridedLayoutAttr::get(ty.getContext(), offset,
+                                          ArrayRef<int64_t>{1});
+    }
+
+    return MemRefType::get(getLinearizedShape(ty, width, loadStoreWidth),
+                           newElemTy, layoutAttr, ty.getMemorySpace());
   });
 }
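
The rewritten load above reduces to three steps: load the containing wide element at `linearizedIndex floordiv scale`, shift right by `(linearizedIndex mod scale) * srcBits` (the quantity `getOffsetForBitwidth` computes), and mask off the low `srcBits` bits. Below is a standalone C++ sketch of that arithmetic on a packed byte buffer, as an illustration only; `loadSubByte` is hypothetical and ignores the signed-shift and truncation details the pattern also handles.

#include <cassert>
#include <cstdint>
#include <vector>

// Extract the `srcBits`-wide element at linear index `idx` from a buffer of
// packed `dstBits`-wide words, using the same arithmetic as the rewritten
// load: word = buf[idx / scale], bitOffset = (idx % scale) * srcBits,
// result = (word >> bitOffset) & mask. Illustrative sketch only.
uint8_t loadSubByte(const std::vector<uint8_t> &buf, int64_t idx, int srcBits,
                    int dstBits) {
  assert(dstBits % srcBits == 0);
  int64_t scale = dstBits / srcBits;
  uint8_t word = buf[idx / scale];         // scaled linearized index
  int bitOffset = (idx % scale) * srcBits; // as in getOffsetForBitwidth
  uint8_t mask = (1u << srcBits) - 1;      // extract the low srcBits bits
  return (word >> bitOffset) & mask;
}

int main() {
  // Two i8 bytes holding four packed i4 values: 0x21 -> {1, 2}, 0x43 -> {3, 4}
  // under the (idx % 2) * 4 bit-offset convention used above.
  std::vector<uint8_t> buf = {0x21, 0x43};
  assert(loadSubByte(buf, 0, 4, 8) == 1);
  assert(loadSubByte(buf, 1, 4, 8) == 2);
  assert(loadSubByte(buf, 2, 4, 8) == 3);
  assert(loadSubByte(buf, 3, 4, 8) == 4);
  return 0;
}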

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 11 additions & 7 deletions
@@ -687,13 +687,17 @@ struct ExtractStridedMetadataOpAllocFolder
 
     auto baseBufferType = cast<MemRefType>(op.getBaseBuffer().getType());
     int64_t offset = 0;
-    if (allocLikeOp.getType() == baseBufferType)
-      results.push_back(allocLikeOp);
-    else
-      results.push_back(rewriter.create<memref::ReinterpretCastOp>(
-          loc, baseBufferType, allocLikeOp, offset,
-          /*sizes=*/ArrayRef<int64_t>(),
-          /*strides=*/ArrayRef<int64_t>()));
+    if (op.getBaseBuffer().use_empty()) {
+      results.push_back(nullptr);
+    } else {
+      if (allocLikeOp.getType() == baseBufferType)
+        results.push_back(allocLikeOp);
+      else
+        results.push_back(rewriter.create<memref::ReinterpretCastOp>(
+            loc, baseBufferType, allocLikeOp, offset,
+            /*sizes=*/ArrayRef<int64_t>(),
+            /*strides=*/ArrayRef<int64_t>()));
+    }
 
     // Offset.
     results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset));

0 commit comments