
Commit 229aa00

[mlir] Add subbyte emulation support for memref.store.
This adds a conversion for narrow type emulation of memref.store ops. The conversion replaces each memref.store with two memref.atomic_rmw ops. Atomics are used to prevent race conditions when two threads store to different subbyte elements that share the same byte.
1 parent a90e215 commit 229aa00
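To make the approach concrete, here is a minimal, self-contained C++ sketch of the read-modify-write that the two emitted atomics perform, assuming an i4 element packed into an i8 byte. The function name, widths, and values are illustrative only and are not part of the patch; the actual conversion emits a memref.atomic_rmw andi followed by a memref.atomic_rmw ori.

    #include <cstdint>
    #include <cstdio>

    // Store a 4-bit value into the nibble of `byte` that starts at `bitOffset`
    // (0 or 4). The two steps mirror the generated atomics: an AND with a
    // clearing mask, then an OR with the value shifted into position.
    uint8_t emulatedStoreI4(uint8_t byte, uint8_t value, unsigned bitOffset) {
      uint8_t clearMask = uint8_t(~(0xFu << bitOffset));         // atomic andi operand
      uint8_t alignedVal = uint8_t((value & 0xFu) << bitOffset); // atomic ori operand
      byte &= clearMask;  // clear the destination bits
      byte |= alignedVal; // merge in the new bits
      return byte;
    }

    int main() {
      // Writing 0xA into the high nibble of 0x3C yields 0xAC.
      std::printf("0x%02X\n", unsigned(emulatedStoreI4(0x3C, 0xA, 4)));
      return 0;
    }

Because the clear and the merge each only touch the destination element's bits, a concurrent writer to the other nibble can interleave between them without losing its bits, which is the same-byte race the commit message refers to.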

File tree

2 files changed: 300 additions, 27 deletions

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 131 additions & 27 deletions
@@ -17,7 +17,9 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -102,13 +104,64 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
   AffineExpr s0;
   bindSymbols(builder.getContext(), s0);
   int scaleFactor = targetBits / sourceBits;
-  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
-      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
+  OpFoldResult offsetVal =
+      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
   Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
   IntegerType dstType = builder.getIntegerType(targetBits);
   return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }
 
+/// When writing a subbyte size, masked bitwise operations are used to only
+/// modify the relevant bits. This function returns an and mask for clearing
+/// the destination bits in a subbyte write. E.g., when writing to the second
+/// i4 in an i32, 0xFFFFFF0F is created.
+static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
+                                 int64_t srcBits, int64_t dstBits,
+                                 Value bitwidthOffset, OpBuilder &builder) {
+  auto dstIntegerType = builder.getIntegerType(dstBits);
+  auto maskRightAlignedAttr =
+      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+  Value maskRightAligned = builder.create<arith::ConstantOp>(
+      loc, dstIntegerType, maskRightAlignedAttr);
+  Value writeMaskInverse =
+      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+  Value flipVal =
+      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
+  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`.
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+                                      OpFoldResult linearizedIndex,
+                                      int64_t srcBits, int64_t dstBits) {
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int64_t scaler = dstBits / srcBits;
+  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+                        const SmallVector<OpFoldResult> &indices,
+                        Value memref) {
+  auto stridedMetadata =
+      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+  OpFoldResult linearizedIndices;
+  std::tie(std::ignore, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          builder, loc, srcBits, srcBits,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(), indices);
+  return linearizedIndices;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
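As a quick sanity check of the helper arithmetic above, the following standalone C++ snippet reproduces the documented mask example. The concrete values (srcBits = 4, dstBits = 32, linearized index 9) are chosen for illustration and do not come from the patch.

    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t srcBits = 4, dstBits = 32;
      const int64_t scale = dstBits / srcBits; // 8 i4 elements per i32 word
      const int64_t linearIndex = 9;           // an arbitrary i4 position

      // getIndicesForLoadOrStore: index in dstBits granularity.
      int64_t wordIndex = linearIndex / scale;             // 1
      // getOffsetForBitwidth: bit offset of the element inside the word.
      int64_t bitOffset = (linearIndex % scale) * srcBits; // 4
      // getSubByteWriteMask: right-aligned mask, shifted, then flipped with xor -1.
      uint32_t writeMaskInverse = ((1u << srcBits) - 1) << bitOffset; // 0x000000F0
      uint32_t writeMask = writeMaskInverse ^ 0xFFFFFFFFu;            // 0xFFFFFF0F

      std::printf("word %lld, bit offset %lld, mask 0x%08X\n",
                  (long long)wordIndex, (long long)bitOffset, writeMask);
      return 0;
    }

The printed mask, 0xFFFFFF0F, matches the example in the getSubByteWriteMask comment: every bit is kept except the four that the new i4 value will occupy.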
@@ -218,32 +271,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
       bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                  ValueRange{});
     } else {
-      SmallVector<OpFoldResult> indices =
-          getAsOpFoldResult(adaptor.getIndices());
-
-      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-          loc, op.getMemRef());
-
       // Linearize the indices of the original load instruction. Do not account
       // for the scaling yet. This will be accounted for later.
-      OpFoldResult linearizedIndices;
-      std::tie(std::ignore, linearizedIndices) =
-          memref::getLinearizedMemRefOffsetAndSize(
-              rewriter, loc, srcBits, srcBits,
-              stridedMetadata.getConstifiedMixedOffset(),
-              stridedMetadata.getConstifiedMixedSizes(),
-              stridedMetadata.getConstifiedMixedStrides(), indices);
-
-      AffineExpr s0;
-      bindSymbols(rewriter.getContext(), s0);
-      int64_t scaler = dstBits / srcBits;
-      OpFoldResult scaledLinearizedIndices =
-          affine::makeComposedFoldedAffineApply(
-              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
       Value newLoad = rewriter.create<memref::LoadOp>(
           loc, adaptor.getMemref(),
-          getValueOrCreateConstantIndexOp(rewriter, loc,
-                                          scaledLinearizedIndices));
+          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+                                   dstBits));
 
       // Get the offset and shift the bits to the rightmost.
       // Note, currently only the big-endian is supported.
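The load path that the refactor above simplifies works in the opposite direction: load the wide word, shift the element to the rightmost bits (per the comment above), and truncate to srcBits. A tiny C++ sketch with assumed widths (i4 elements emulated in an i32 word; the name is illustrative):

    #include <cstdint>

    // Extract the i4 element that starts at `bitOffset` from an emulated i32
    // word: shift it down to the least significant bits, then keep the low 4.
    inline uint8_t emulatedLoadI4(uint32_t word, unsigned bitOffset) {
      return uint8_t((word >> bitOffset) & 0xFu);
    }

For example, emulatedLoadI4(0x0000A300, 8) returns 0x3 and emulatedLoadI4(0x0000A300, 12) returns 0xA.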
@@ -305,6 +341,74 @@ struct ConvertMemRefReinterpretCast final
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    int srcBits = op.getMemRefType().getElementTypeBitWidth();
+    int dstBits = convertedType.getElementTypeBitWidth();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    Location loc = op.getLoc();
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+                                                          adaptor.getValue());
+
+    // Special case 0-rank memref stores. We compute the mask at compile time.
+    if (convertedType.getRank() == 0) {
+      // Create mask to clear destination bits
+      auto writeMaskValAttr =
+          rewriter.getIntegerAttr(dstIntegerType, ~(1 << (srcBits)) - 1);
+      Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType,
+                                                           writeMaskValAttr);
+
+      // Clear destination bits
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                           writeMask, adaptor.getMemref(),
+                                           ValueRange{});
+      // Write srcs bits to destination
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                           extendedInput, adaptor.getMemref(),
+                                           ValueRange{});
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+    Value storeIndices = getIndicesForLoadOrStore(
+        rewriter, loc, linearizedIndices, srcBits, dstBits);
+    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+                                                dstBits, rewriter);
+    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
+                                          dstBits, bitwidthOffset, rewriter);
+    // Align the value to write with the destination bits
+    Value alignedVal =
+        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+
+    // Clear destination bits
+    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                         writeMask, adaptor.getMemref(),
+                                         storeIndices);
+    // Write srcs bits to destination
+    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                         alignedVal, adaptor.getMemref(),
+                                         storeIndices);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
@@ -350,9 +454,9 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
                ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
-               ConvertMemRefReinterpretCast>(typeConverter,
-                                             patterns.getContext());
+               ConvertMemrefStore, ConvertMemRefAssumeAlignment,
+               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
+      typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
