
Commit 3a6f02a

[mlir] Add subbyte emulation support for memref.store. (#73174)
This adds a conversion for narrow type emulation of memref.store ops. The conversion replaces each memref.store with two memref.atomic_rmw ops. Atomics are used to prevent race conditions on same-byte accesses, in the event that two threads are storing into the same byte.

Fixes iree-org/iree#15370
1 parent e88a1ce commit 3a6f02a
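For intuition, the lowering can be pictured outside MLIR: each sub-byte store becomes a clear step (an atomic "and" with a mask) followed by a set step (an atomic "or" with the shifted value). The standalone C++ sketch below is illustrative only; the buffer layout, the function name storeI4, and the sample values are assumptions for the example, while the comments point at the helpers introduced in this diff.

// Standalone illustration (not MLIR code) of the clear-then-set scheme the
// generated IR performs when packing i4 values into i8 storage.
#include <atomic>
#include <cassert>
#include <cstdint>

// Store the 4-bit value `v` at element index `idx` of a buffer whose bytes
// each hold two i4 elements, mirroring the two memref.atomic_rmw ops.
void storeI4(std::atomic<uint8_t> *bytes, int64_t idx, uint8_t v) {
  int64_t byteIndex = idx / 2;                        // cf. getIndicesForLoadOrStore
  int bitOffset = int(idx % 2) * 4;                   // cf. getOffsetForBitwidth
  uint8_t writeMask = uint8_t(~(0xF << bitOffset));   // cf. getSubByteWriteMask
  uint8_t aligned = uint8_t((v & 0xF) << bitOffset);  // value aligned to its nibble
  bytes[byteIndex].fetch_and(writeMask);              // memref.atomic_rmw andi: clear
  bytes[byteIndex].fetch_or(aligned);                 // memref.atomic_rmw ori: set
}

int main() {
  std::atomic<uint8_t> buf[3];
  for (auto &b : buf)
    b.store(0xFF); // start with all bits set so the clearing step is visible
  storeI4(buf, 3, 0x2); // writes the second i4 of byte 1
  assert(buf[1].load() == 0x2F);
  return 0;
}

Because both steps are read-modify-write atomics on the whole byte, a concurrent store to the other nibble of the same byte cannot be lost, which is the race the commit message refers to.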

2 files changed: 285 additions, 27 deletions

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 120 additions & 27 deletions
@@ -17,7 +17,9 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -102,13 +104,64 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
   AffineExpr s0;
   bindSymbols(builder.getContext(), s0);
   int scaleFactor = targetBits / sourceBits;
-  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
-      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
+  OpFoldResult offsetVal =
+      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
   Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
   IntegerType dstType = builder.getIntegerType(targetBits);
   return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }
 
+/// When writing a subbyte size, masked bitwise operations are used to only
+/// modify the relevant bits. This function returns an and mask for clearing
+/// the destination bits in a subbyte write. E.g., when writing to the second
+/// i4 in an i32, 0xFFFFFF0F is created.
+static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices,
+                                 int64_t srcBits, int64_t dstBits,
+                                 Value bitwidthOffset, OpBuilder &builder) {
+  auto dstIntegerType = builder.getIntegerType(dstBits);
+  auto maskRightAlignedAttr =
+      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+  Value maskRightAligned = builder.create<arith::ConstantOp>(
+      loc, dstIntegerType, maskRightAlignedAttr);
+  Value writeMaskInverse =
+      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+  Value flipVal =
+      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr);
+  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`.
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+                                      OpFoldResult linearizedIndex,
+                                      int64_t srcBits, int64_t dstBits) {
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int64_t scaler = dstBits / srcBits;
+  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+                        const SmallVector<OpFoldResult> &indices,
+                        Value memref) {
+  auto stridedMetadata =
+      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+  OpFoldResult linearizedIndices;
+  std::tie(std::ignore, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          builder, loc, srcBits, srcBits,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(), indices);
+  return linearizedIndices;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -218,32 +271,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
       bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                  ValueRange{});
     } else {
-      SmallVector<OpFoldResult> indices =
-          getAsOpFoldResult(adaptor.getIndices());
-
-      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-          loc, op.getMemRef());
-
       // Linearize the indices of the original load instruction. Do not account
       // for the scaling yet. This will be accounted for later.
-      OpFoldResult linearizedIndices;
-      std::tie(std::ignore, linearizedIndices) =
-          memref::getLinearizedMemRefOffsetAndSize(
-              rewriter, loc, srcBits, srcBits,
-              stridedMetadata.getConstifiedMixedOffset(),
-              stridedMetadata.getConstifiedMixedSizes(),
-              stridedMetadata.getConstifiedMixedStrides(), indices);
-
-      AffineExpr s0;
-      bindSymbols(rewriter.getContext(), s0);
-      int64_t scaler = dstBits / srcBits;
-      OpFoldResult scaledLinearizedIndices =
-          affine::makeComposedFoldedAffineApply(
-              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
       Value newLoad = rewriter.create<memref::LoadOp>(
           loc, adaptor.getMemref(),
-          getValueOrCreateConstantIndexOp(rewriter, loc,
-                                          scaledLinearizedIndices));
+          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+                                   dstBits));
 
       // Get the offset and shift the bits to the rightmost.
       // Note, currently only the big-endian is supported.
@@ -305,6 +341,63 @@ struct ConvertMemRefReinterpretCast final
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    int srcBits = op.getMemRefType().getElementTypeBitWidth();
+    int dstBits = convertedType.getElementTypeBitWidth();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    Location loc = op.getLoc();
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+                                                          adaptor.getValue());
+
+    // Special case 0-rank memref stores. No need for masking.
+    if (convertedType.getRank() == 0) {
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign,
+                                           extendedInput, adaptor.getMemref(),
+                                           ValueRange{});
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+    Value storeIndices = getIndicesForLoadOrStore(
+        rewriter, loc, linearizedIndices, srcBits, dstBits);
+    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+                                                dstBits, rewriter);
+    Value writeMask = getSubByteWriteMask(loc, linearizedIndices, srcBits,
+                                          dstBits, bitwidthOffset, rewriter);
+    // Align the value to write with the destination bits
+    Value alignedVal =
+        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+
+    // Clear destination bits
+    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                         writeMask, adaptor.getMemref(),
+                                         storeIndices);
+    // Write srcs bits to destination
+    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                         alignedVal, adaptor.getMemref(),
+                                         storeIndices);
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
@@ -350,9 +443,9 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
                ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
-               ConvertMemRefReinterpretCast>(typeConverter,
-                                             patterns.getContext());
+               ConvertMemrefStore, ConvertMemRefAssumeAlignment,
+               ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
+      typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
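For context on how the pattern registered above is exercised: a narrow-type emulation pass typically pairs these patterns with a type converter that widens sub-byte element types to the chosen load/store bitwidth. The sketch below is a rough outline under that assumption; only populateMemRefNarrowTypeEmulationPatterns appears in this commit, so the converter class, the populateMemRefNarrowTypeEmulationConversions call, the legality setup, and the function name emulateNarrowTypes are assumptions about the surrounding flow and may not match the in-tree test pass exactly.

// Hypothetical driver for the patterns touched in this diff (sketch only).
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

LogicalResult emulateNarrowTypes(Operation *root, unsigned loadStoreBitwidth) {
  MLIRContext *ctx = root->getContext();

  // Converter that widens sub-byte integer element types (e.g. i4 -> i8).
  // Assumed API; not part of this commit.
  arith::NarrowTypeEmulationConverter converter(loadStoreBitwidth);
  memref::populateMemRefNarrowTypeEmulationConversions(converter);

  RewritePatternSet patterns(ctx);
  // This is the populate function extended by the commit above.
  memref::populateMemRefNarrowTypeEmulationPatterns(converter, patterns);

  // Ops whose types the converter already accepts are left alone.
  ConversionTarget target(*ctx);
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *op) { return converter.isLegal(op); });

  return applyPartialConversion(root, target, std::move(patterns));
}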

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 165 additions & 0 deletions
@@ -265,3 +265,168 @@ func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
 // CHECK32: %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
 // CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
 // CHECK32: return %[[TRUNC]]
+
+// -----
+
+func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
+  %0 = memref.alloc() : memref<5xi4>
+  memref.store %arg1, %0[%arg0] : memref<5xi4>
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+// CHECK: func @memref_store_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+// CHECK32: return
+
+// -----
+
+func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
+  %0 = memref.alloc() : memref<3x125xi4>
+  memref.assume_alignment %0, 64 : memref<3x125xi4>
+  memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
+// CHECK: func @memref_store_i4_rank2(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
+// CHECK-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4_rank2(
+// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
+// CHECK32-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+// CHECK32: return
+
+// -----
+
+func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: i4) -> () {
+  %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
+  memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
+  return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
+// CHECK: func @memref_store_i4_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
+// CHECK-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4_dynamic(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
+// CHECK32-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<?xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<?xi32>) -> i32
+// CHECK32: return
+
+// -----
+
+func.func @rank_zero_memref_store(%arg0: i4) -> () {
+  %0 = memref.alloc() : memref<i4>
+  memref.store %arg0, %0[] : memref<i4>
+  return
+}
+// CHECK-LABEL: func @rank_zero_memref
+// CHECK-SAME: %[[ARG0:.+]]: i4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<i8>
+// CHECK: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i8, memref<i8>) -> i8
+// CHECK: return
+
+// CHECK32-LABEL: func @rank_zero_memref
+// CHECK32-SAME: %[[ARG0:.+]]: i4
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<i32>
+// CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
+// CHECK32: return
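To tie the CHECK lines to concrete numbers: for @memref_store_i4 with, say, %arg0 = 3, #MAP0 picks the backing element and #MAP1 the bit offset inside it. The short standalone C++ check below is illustrative only (no MLIR involved; the sample index 3 is an assumption, and plain / matches floordiv here because the indices are non-negative). It evaluates both maps for the i8 and i32 configurations and the mask the lowering builds from them.

// Evaluate the affine maps from the CHECK lines for a sample index, showing
// where an i4 store at linear index 3 lands in i8- and i32-backed storage.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t s0 = 3; // sample value for %arg0

  // CHECK (i8 storage): #MAP0 = s0 floordiv 2, #MAP1 = s0 * 4 - (s0 floordiv 2) * 8
  int64_t index8 = s0 / 2;
  int64_t offset8 = s0 * 4 - (s0 / 2) * 8;
  assert(index8 == 1 && offset8 == 4); // second nibble of byte 1

  // CHECK32 (i32 storage): #MAP0 = s0 floordiv 8, #MAP1 = s0 * 4 - (s0 floordiv 8) * 32
  int64_t index32 = s0 / 8;
  int64_t offset32 = s0 * 4 - (s0 / 8) * 32;
  assert(index32 == 0 && offset32 == 12); // bits 12..15 of word 0

  // Mask built by the lowering: xori(15 << offset, -1) clears only those bits.
  int32_t mask32 = (15 << offset32) ^ -1;
  assert(mask32 == int32_t(0xFFFF0FFF));
  return 0;
}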
