[mlir] Add subbyte emulation support for memref.store
#73174
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-memref

Author: None (Max191)

Changes

This adds a conversion for narrow type emulation of memref.store ops. The conversion replaces the memref.store with two memref.atomic_rmw ops. Atomics are used to prevent race conditions on same-byte accesses, in the event that two threads are storing into the same byte.

Patch is 22.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/73174.diff

2 Files Affected:
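For a quick picture of the emulation, here is a condensed sketch of the i8 output for the first lit test added in this patch (a store into memref<5xi4>), reproduced from the CHECK lines further below; the SSA names are illustrative, not the ones the pattern generates:

func.func @memref_store_i4(%idx: index, %val: i4) {
  %0 = memref.alloc() : memref<5xi4>
  memref.store %val, %0[%idx] : memref<5xi4>
  return
}

// becomes, roughly:

%alloc     = memref.alloc() : memref<3xi8>
%ext       = arith.extui %val : i4 to i8
%byte_idx  = affine.apply affine_map<()[s0] -> (s0 floordiv 2)>()[%idx]
%bit_off   = affine.apply affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>()[%idx]
%off_i8    = arith.index_cast %bit_off : index to i8
%mask_base = arith.constant 15 : i8
%mask_inv  = arith.shli %mask_base, %off_i8 : i8
%cst_neg1  = arith.constant -1 : i8
%mask      = arith.xori %mask_inv, %cst_neg1 : i8
%store_val = arith.shli %ext, %off_i8 : i8
%cleared   = memref.atomic_rmw andi %mask, %alloc[%byte_idx] : (i8, memref<3xi8>) -> i8
%written   = memref.atomic_rmw ori %store_val, %alloc[%byte_idx] : (i8, memref<3xi8>) -> i8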
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..58bcf7b9ddde552 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,6 +17,9 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
@@ -29,6 +32,26 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//
+/// Replaces the memref::StoreOp with two new memref::AtomicRMWOps. The first
+/// memref::AtomicRMWOp sets the destination bits to all zero to prepare the
+/// destination byte to be written to. The second memref::AtomicRMWOp does the
+/// writing of the value to store, using an `ori` type operation. The value
+/// to store and the write mask should both have the destination type bitwidth,
+/// and the bits of the value to store should be all zero except for the bits
+/// aligned with the store destination.
+static void replaceStoreWithAtomics(ConversionPatternRewriter &rewriter,
+ memref::StoreOp op, Value writeMask,
+ Value storeVal, Value memref,
+ ValueRange storeIndices) {
+ // Clear destination bits
+ rewriter.create<memref::AtomicRMWOp>(op.getLoc(), arith::AtomicRMWKind::andi,
+ writeMask, memref, storeIndices);
+ // Write srcs bits to destination
+ rewriter.create<memref::AtomicRMWOp>(op->getLoc(), arith::AtomicRMWKind::ori,
+ storeVal, memref, storeIndices);
+ rewriter.eraseOp(op);
+}
+
/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
/// treated as an array of elements of width `sourceBits`.
@@ -43,13 +66,67 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
AffineExpr s0;
bindSymbols(builder.getContext(), s0);
int scaleFactor = targetBits / sourceBits;
- OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
- builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+ AffineExpr offsetExpr = (s0 % scaleFactor) * sourceBits;
+ OpFoldResult offsetVal =
+ affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
IntegerType dstType = builder.getIntegerType(targetBits);
return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
}
+/// When writing a subbyte size, writing needs to happen atomically in case of
+/// another write happening on the same byte at the same time. To do the write,
+/// we first must clear `dstBits` at the `linearizedIndices` of the subbyte
+/// store. This function returns the appropriate mask for clearing these bits.
+static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices,
+ int64_t srcBits, int64_t dstBits,
+ Value bitwidthOffset, OpBuilder &builder) {
+ auto dstIntegerType = builder.getIntegerType(dstBits);
+ auto maskRightAlignedAttr =
+ builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+ Value maskRightAligned =
+ builder
+ .create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
+ .getResult();
+ Value writeMaskInverse =
+ builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+ auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+ Value flipVal =
+ builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+ .getResult();
+ return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+ OpFoldResult linearizedIndex,
+ int64_t srcBits, int64_t dstBits) {
+ AffineExpr s0;
+ bindSymbols(builder.getContext(), s0);
+ int64_t scaler = dstBits / srcBits;
+ OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+ builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+ return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+ const SmallVector<OpFoldResult> &indices,
+ Value memref) {
+ auto stridedMetadata =
+ builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+ OpFoldResult linearizedIndices;
+ std::tie(std::ignore, linearizedIndices) =
+ memref::getLinearizedMemRefOffsetAndSize(
+ builder, loc, srcBits, srcBits,
+ stridedMetadata.getConstifiedMixedOffset(),
+ stridedMetadata.getConstifiedMixedSizes(),
+ stridedMetadata.getConstifiedMixedStrides(), indices);
+ return linearizedIndices;
+}
+
namespace {
//===----------------------------------------------------------------------===//
@@ -155,32 +232,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
ValueRange{});
} else {
- SmallVector<OpFoldResult> indices =
- getAsOpFoldResult(adaptor.getIndices());
-
- auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
- loc, op.getMemRef());
-
// Linearize the indices of the original load instruction. Do not account
// for the scaling yet. This will be accounted for later.
- OpFoldResult linearizedIndices;
- std::tie(std::ignore, linearizedIndices) =
- memref::getLinearizedMemRefOffsetAndSize(
- rewriter, loc, srcBits, srcBits,
- stridedMetadata.getConstifiedMixedOffset(),
- stridedMetadata.getConstifiedMixedSizes(),
- stridedMetadata.getConstifiedMixedStrides(), indices);
-
- AffineExpr s0;
- bindSymbols(rewriter.getContext(), s0);
- int64_t scaler = dstBits / srcBits;
- OpFoldResult scaledLinearizedIndices =
- affine::makeComposedFoldedAffineApply(
- rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+ OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+ rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
Value newLoad = rewriter.create<memref::LoadOp>(
loc, adaptor.getMemref(),
- getValueOrCreateConstantIndexOp(rewriter, loc,
- scaledLinearizedIndices));
+ getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+ dstBits));
// Get the offset and shift the bits to the rightmost.
// Note, currently only the big-endian is supported.
@@ -211,6 +271,60 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
}
};
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+ int srcBits = op.getMemRefType().getElementTypeBitWidth();
+ int dstBits = convertedType.getElementTypeBitWidth();
+ auto dstIntegerType = rewriter.getIntegerType(dstBits);
+ if (dstBits % srcBits != 0) {
+ return rewriter.notifyMatchFailure(
+ op, "only dstBits % srcBits == 0 supported");
+ }
+
+ Location loc = op.getLoc();
+ Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType,
+ adaptor.getValue());
+
+ // Special case 0-rank memref stores. We can compute the mask at compile
+ // time.
+ if (convertedType.getRank() == 0) {
+ // Create mask to clear destination bits
+ auto writeMaskValAttr =
+ rewriter.getIntegerAttr(dstIntegerType, ~(1 << (srcBits)) - 1);
+ Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType,
+ writeMaskValAttr);
+
+ replaceStoreWithAtomics(rewriter, op, writeMask, extendedInput,
+ adaptor.getMemref(), ValueRange{});
+ return success();
+ }
+
+ OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+ rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+ Value storeIndices = getIndicesForLoadOrStore(
+ rewriter, loc, linearizedIndices, srcBits, dstBits);
+ Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+ dstBits, rewriter);
+ Value writeMask = getAtomicWriteMask(loc, linearizedIndices, srcBits,
+ dstBits, bitwidthOffset, rewriter);
+ // Align the value to write with the destination bits
+ Value alignedVal =
+ rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset);
+ replaceStoreWithAtomics(rewriter, op, writeMask, alignedVal,
+ adaptor.getMemref(), storeIndices);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
@@ -291,9 +405,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
RewritePatternSet &patterns) {
// Populate `memref.*` conversion patterns.
- patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
- ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
- typeConverter, patterns.getContext());
+ patterns
+ .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
+ ConvertMemRefSubview, ConvertMemrefStore>(typeConverter,
+ patterns.getContext());
memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cff2..22c5947fd2ac97b 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,172 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
+ %0 = memref.alloc() : memref<5xi4>
+ memref.store %arg1, %0[%arg0] : memref<5xi4>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
+// CHECK: func @memref_store_i4(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4(
+// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<1xi32>) -> i32
+// CHECK32: return
+
+// -----
+
+func.func @memref_store_i4_rank2(%arg0: index, %arg1: index, %arg2: i4) -> () {
+ %0 = memref.alloc() : memref<3x125xi4>
+ memref.assume_alignment %0, 64 : memref<3x125xi4>
+ memref.store %arg2, %0[%arg0,%arg1] : memref<3x125xi4>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 2) * 8)>
+// CHECK: func @memref_store_i4_rank2(
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<188xi8>
+// CHECK-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<188xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<188xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * 125 + s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 500 + s1 * 4 - ((s0 * 125 + s1) floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4_rank2(
+// CHECK32-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i4
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc() : memref<47xi32>
+// CHECK32-DAG: memref.assume_alignment %[[ALLOC]], 64 : memref<47xi32>
+// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG2]] : i4 to i32
+// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+// CHECK32-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i32
+// CHECK32-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I32]] : i32
+// CHECK32-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i32
+// CHECK32-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i32
+// CHECK32-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I32]] : i32
+// CHECK32: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+// CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i32, memref<47xi32>) -> i32
+// CHECK32: return
+
+// -----
+
+func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: i4) -> () {
+ %0 = memref.alloc(%arg0, %arg1) : memref<?x?xi4>
+ memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
+ return
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
+// CHECK: func @memref_store_i4_dynamic(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
+// CHECK-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
+// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[BITOFFSET:.+]] = affine.apply #[[MAP2]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
+// CHECK-DAG: %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+// CHECK-DAG: %[[MASK_BASE:.+]] = arith.constant 15 : i8
+// CHECK-DAG: %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+// CHECK-DAG: %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+// CHECK-DAG: %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+// CHECK-DAG: %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+// CHECK: %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
+// CHECK: return
+
+// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
+// CHECK32: func @memref_store_i4_dynamic(
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
+// CHECK32-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
+// CHECK3...
[truncated]
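As a sanity check on the mask construction (the getAtomicWriteMask and getOffsetForBitwidth helpers in the diff above), here is the arithmetic worked out for assumed values srcBits = 4, dstBits = 8, and linearized element index 3; the concrete numbers are illustrative only:

// scaleFactor      = dstBits / srcBits     = 8 / 4 = 2
// byte index       = 3 floordiv 2          = 1
// bit offset       = (3 mod 2) * 4         = 4
// maskRightAligned = (1 << 4) - 1          = 0x0F
// writeMaskInverse = 0x0F << 4             = 0xF0
// writeMask        = 0xF0 xor 0xFF (-1)    = 0x0F
// alignedVal       = extui(value to i8) << 4
// atomic_rmw andi writeMask  : clears the destination nibble (bits 4..7)
// atomic_rmw ori  alignedVal : ors the shifted value into the cleared nibble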
Force-pushed from 54798b9 to 67b7336
Thanks, overall looks good! Just two nits and one major comment, please take a look.
Force-pushed from 67b7336 to 229aa00
This adds a conversion for narrow type emulation of memref.store ops. The conversion replaces the memref.store with two memref.atomic_rmw ops. Atomics are used to prevent race conditions on same-byte accesses, in the event that two threads are storing into the same byte.
Fixes iree-org/iree#15370