
[mlir] Add narrow type emulation conversions #72181


Closed · wants to merge 1 commit

Conversation

Max191 (Contributor) commented Nov 14, 2023

Adds narrow type emulation support for:
- memref.alloca
- memref.store
- memref.reinterpret_cast

Fixes iree-org/iree#15370
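
For context, a minimal sketch of how these patterns are typically wired into a conversion. `populateMemRefNarrowTypeEmulationPatterns` is the function this patch extends; the surrounding scaffolding (the `emulateNarrowTypes` helper, the use of `arith::NarrowTypeEmulationConverter` and `populateMemRefNarrowTypeEmulationConversions`, and the 8-bit target width) is illustrative and assumed, not part of this PR:

#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Illustrative driver: emulate sub-byte integers (e.g. i4) on 8-bit storage.
static LogicalResult emulateNarrowTypes(Operation *root) {
  MLIRContext *ctx = root->getContext();
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

  RewritePatternSet patterns(ctx);
  // After this PR, this also registers the memref.alloca, memref.store and
  // memref.reinterpret_cast patterns.
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);

  ConversionTarget target(*ctx);
  target.markUnknownOpDynamicallyLegal(
      [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });
  return applyPartialConversion(root, target, std::move(patterns));
}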

llvmbot (Member) commented Nov 14, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-memref

Author: None (Max191)

Changes

Adds narrow type emulation support for:
- memref.alloca
- memref.store
- memref.reinterpret_cast


Patch is 30.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72181.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+218-40)
  • (modified) mlir/test/Dialect/MemRef/emulate-narrow-type.mlir (+228)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 9f58e9055acadbb..078df55e351db96 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -35,36 +36,98 @@ using namespace mlir;
 /// Return the bit offset of the value at position `srcIdx`. For example, if
 /// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
 /// located at (x % 2) * 4. Because there are two elements in one i8, and one
-/// element has 4 bits.
+/// element has 4 bits. If `rightOffset` is true, return the offset from the
+/// right side of the `dstBits` container instead of the left side.
 static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx,
                                   int sourceBits, int targetBits,
-                                  OpBuilder &builder) {
+                                  OpBuilder &builder,
+                                  bool rightOffset = false) {
   assert(targetBits % sourceBits == 0);
   AffineExpr s0;
   bindSymbols(builder.getContext(), s0);
   int scaleFactor = targetBits / sourceBits;
-  OpFoldResult offsetVal = affine::makeComposedFoldedAffineApply(
-      builder, loc, (s0 % scaleFactor) * sourceBits, {srcIdx});
+  AffineExpr offsetExpr =
+      rightOffset ? (scaleFactor - 1 - s0 % scaleFactor) * sourceBits
+                  : (s0 % scaleFactor) * sourceBits;
+  OpFoldResult offsetVal =
+      affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx});
   Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal);
   IntegerType dstType = builder.getIntegerType(targetBits);
   return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset);
 }
 
+/// When writing a subbyte size, writing needs to happen atomically in case of
+/// another write happening on the same byte at the same time. To do the write,
+/// we first must clear `dstBits` at the `linearizedIndices` of the subbyte
+/// store. This function returns the appropriate mask for clearing these bits.
+static Value getAtomicWriteMask(Location loc, OpFoldResult linearizedIndices,
+                                int64_t srcBits, int64_t dstBits,
+                                Value bitwidthOffset, OpBuilder &builder) {
+  auto dstIntegerType = builder.getIntegerType(dstBits);
+  auto maskRightAlignedAttr =
+      builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1);
+  Value maskRightAligned =
+      builder
+          .create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
+          .getResult();
+  Value writeMaskInverse =
+      builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset);
+  auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1);
+  Value flipVal =
+      builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr)
+          .getResult();
+  return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal);
+}
+
+/// Returns the scaled linearized index based on the `srcBits` and `dstBits`
+/// sizes. The input `linearizedIndex` has the granularity of `srcBits`, and
+/// the returned index has the granularity of `dstBits`.
+static Value getIndicesForLoadOrStore(OpBuilder &builder, Location loc,
+                                      OpFoldResult linearizedIndex,
+                                      int64_t srcBits, int64_t dstBits) {
+  AffineExpr s0;
+  bindSymbols(builder.getContext(), s0);
+  int64_t scaler = dstBits / srcBits;
+  OpFoldResult scaledLinearizedIndices = affine::makeComposedFoldedAffineApply(
+      builder, loc, s0.floorDiv(scaler), {linearizedIndex});
+  return getValueOrCreateConstantIndexOp(builder, loc, scaledLinearizedIndices);
+}
+
+static OpFoldResult
+getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits,
+                        const SmallVector<OpFoldResult> &indices,
+                        Value memref) {
+  auto stridedMetadata =
+      builder.create<memref::ExtractStridedMetadataOp>(loc, memref);
+  OpFoldResult linearizedIndices;
+  std::tie(std::ignore, linearizedIndices) =
+      memref::getLinearizedMemRefOffsetAndSize(
+          builder, loc, srcBits, srcBits,
+          stridedMetadata.getConstifiedMixedOffset(),
+          stridedMetadata.getConstifiedMixedSizes(),
+          stridedMetadata.getConstifiedMixedStrides(), indices);
+  return linearizedIndices;
+}
+
 namespace {
 
 //===----------------------------------------------------------------------===//
 // ConvertMemRefAlloc
 //===----------------------------------------------------------------------===//
 
-struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
-  using OpConversionPattern::OpConversionPattern;
+template <typename OpTy>
+struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> {
+  using OpConversionPattern<OpTy>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::AllocOp op, OpAdaptor adaptor,
+  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto currentType = op.getMemref().getType().cast<MemRefType>();
-    auto newResultType =
-        getTypeConverter()->convertType(op.getType()).dyn_cast<MemRefType>();
+    static_assert(std::is_same<OpTy, memref::AllocOp>() ||
+                      std::is_same<OpTy, memref::AllocaOp>(),
+                  "expected only memref::AllocOp or memref::AllocaOp");
+    auto currentType = cast<MemRefType>(op.getMemref().getType());
+    auto newResultType = dyn_cast<MemRefType>(
+        this->getTypeConverter()->convertType(op.getType()));
     if (!newResultType) {
       return rewriter.notifyMatchFailure(
           op->getLoc(),
@@ -73,9 +136,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
 
     // Special case zero-rank memrefs.
     if (currentType.getRank() == 0) {
-      rewriter.replaceOpWithNewOp<memref::AllocOp>(
-          op, newResultType, ValueRange{}, adaptor.getSymbolOperands(),
-          adaptor.getAlignmentAttr());
+      rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, ValueRange{},
+                                        adaptor.getSymbolOperands(),
+                                        adaptor.getAlignmentAttr());
       return success();
     }
 
@@ -97,9 +160,9 @@ struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
           rewriter, loc, linearizedMemRefInfo.linearizedSize));
     }
 
-    rewriter.replaceOpWithNewOp<memref::AllocOp>(
-        op, newResultType, dynamicLinearizedSize, adaptor.getSymbolOperands(),
-        adaptor.getAlignmentAttr());
+    rewriter.replaceOpWithNewOp<OpTy>(op, newResultType, dynamicLinearizedSize,
+                                      adaptor.getSymbolOperands(),
+                                      adaptor.getAlignmentAttr());
     return success();
   }
 };
@@ -155,32 +218,15 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
       bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(),
                                                  ValueRange{});
     } else {
-      SmallVector<OpFoldResult> indices =
-          getAsOpFoldResult(adaptor.getIndices());
-
-      auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-          loc, op.getMemRef());
-
       // Linearize the indices of the original load instruction. Do not account
       // for the scaling yet. This will be accounted for later.
-      OpFoldResult linearizedIndices;
-      std::tie(std::ignore, linearizedIndices) =
-          memref::getLinearizedMemRefOffsetAndSize(
-              rewriter, loc, srcBits, srcBits,
-              stridedMetadata.getConstifiedMixedOffset(),
-              stridedMetadata.getConstifiedMixedSizes(),
-              stridedMetadata.getConstifiedMixedStrides(), indices);
-
-      AffineExpr s0;
-      bindSymbols(rewriter.getContext(), s0);
-      int64_t scaler = dstBits / srcBits;
-      OpFoldResult scaledLinearizedIndices =
-          affine::makeComposedFoldedAffineApply(
-              rewriter, loc, s0.floorDiv(scaler), {linearizedIndices});
+      OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+          rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+
       Value newLoad = rewriter.create<memref::LoadOp>(
           loc, adaptor.getMemref(),
-          getValueOrCreateConstantIndexOp(rewriter, loc,
-                                          scaledLinearizedIndices));
+          getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits,
+                                   dstBits));
 
       // Get the offset and shift the bits to the rightmost.
       // Note, currently only the big-endian is supported.
@@ -211,6 +257,136 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefReinterpretCast
+//===----------------------------------------------------------------------===//
+
+///
+struct ConvertMemRefReinterpretCast final
+    : OpConversionPattern<memref::ReinterpretCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    MemRefType newTy =
+        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    if (!newTy) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+    }
+
+    auto convertedElementType = newTy.getElementType();
+    auto oldElementType = op.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    // Only support offset for 0-D subview.
+    if (op.getType().getRank() != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "subview with rank > 0 is not supported");
+    }
+
+    int64_t offset = op.getStaticOffset(0);
+    // Only support static sizes and offsets.
+    if (offset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "subview with dynamic offset is not supported");
+    }
+
+    int elementsPerByte = dstBits / srcBits;
+    if (offset % elementsPerByte != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(),
+          "subview with offset not multiple of elementsPerByte is not "
+          "supported");
+    }
+
+    offset = offset / elementsPerByte;
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, newTy, *adaptor.getODSOperands(0).begin(), offset,
+        SmallVector<int64_t>{}, op.getStaticStrides());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertMemrefStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto convertedType = adaptor.getMemref().getType().cast<MemRefType>();
+    auto convertedElementType = convertedType.getElementType();
+    auto oldElementType = op.getMemRefType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    auto dstIntegerType = rewriter.getIntegerType(dstBits);
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    Location loc = op.getLoc();
+    Value extendedInput = rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, adaptor.getValue());
+
+    // Special case 0-rank memref stores. We can compute the mask at compile
+    // time.
+    if (convertedType.getRank() == 0) {
+      // Shift extended value to be left aligned
+      auto shiftValAttr = rewriter.getIntegerAttr(dstIntegerType, dstBits - srcBits);
+      Value shiftVal = rewriter.create<arith::ConstantOp>(loc, dstIntegerType, shiftValAttr).getResult();
+      Value alignedVal = rewriter.create<arith::ShLIOp>(loc, extendedInput, shiftVal).getResult();
+      // Create mask to clear destination bits
+      auto writeMaskValAttr = rewriter.getIntegerAttr(
+          dstIntegerType, (1 << (dstBits - srcBits)) - 1);
+      Value writeMask = rewriter.create<arith::ConstantOp>(loc, dstIntegerType, writeMaskValAttr).getResult();
+
+      // Clear destination bits
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi, writeMask, adaptor.getMemref(), ValueRange{});
+      // Write srcs bits to destination
+      rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori, alignedVal, adaptor.getMemref(), ValueRange{});
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    OpFoldResult linearizedIndices = getLinearizedSrcIndices(
+        rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef());
+    Value storeIndices = getIndicesForLoadOrStore(
+        rewriter, loc, linearizedIndices, srcBits, dstBits);
+    Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits,
+                                                dstBits, rewriter, true);
+    Value writeMask = getAtomicWriteMask(loc, linearizedIndices, srcBits,
+                                         dstBits, bitwidthOffset, rewriter);
+    // Align the value to write with the destination bits
+    Value alignedVal =
+        rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset)
+            .getResult();
+
+    // Clear destination bits
+    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi,
+                                         writeMask, adaptor.getMemref(),
+                                         storeIndices);
+    // Write srcs bits to destination
+    rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
+                                         alignedVal, adaptor.getMemref(),
+                                         storeIndices);
+
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertMemRefSubview
 //===----------------------------------------------------------------------===//
@@ -291,8 +467,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `memref.*` conversion patterns.
-  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
-               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
+  patterns.add<ConvertMemRefAlloc<memref::AllocOp>,
+               ConvertMemRefAlloc<memref::AllocaOp>, ConvertMemRefLoad,
+               ConvertMemRefAssumeAlignment, ConvertMemRefSubview,
+               ConvertMemrefStore, ConvertMemRefReinterpretCast>(
       typeConverter, patterns.getContext());
   memref::populateResolveExtractStridedMetadataPatterns(patterns);
 }
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 6ed97f05aa7cff2..05ec5761c8fe024 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,231 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
 //       CHECK32:   %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
 //       CHECK32:   %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
 //       CHECK32:   %[[LOAD:.+]] = memref.load %[[SUBVIEW]]
+
+// -----
+
+func.func @memref_alloca_load_i4(%arg0: index) -> i4 {
+    %0 = memref.alloca() : memref<5xi4>
+    %1 = memref.load %0[%arg0] : memref<5xi4>
+    return %1 : i4
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)
+//      CHECK: func @memref_alloca_load_i4(
+// CHECK-SAME:     %[[ARG0:.+]]: index
+//      CHECK:   %[[ALLOCA:.+]] = memref.alloca() : memref<3xi8>
+//      CHECK:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK:   %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+//      CHECK:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//      CHECK:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
+//      CHECK:   return %[[TRUNC]]
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)
+//      CHECK32: func @memref_alloca_load_i4(
+// CHECK32-SAME:     %[[ARG0:.+]]: index
+//      CHECK32:   %[[ALLOCA:.+]] = memref.alloca() : memref<1xi32>
+//      CHECK32:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//      CHECK32:   %[[LOADVAL:.+]] = memref.load %[[ALLOCA]][%[[INDEX]]]
+//      CHECK32:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//      CHECK32:   %[[CAST:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//      CHECK32:   %[[SHIFTRT:.+]] = arith.shrsi %[[LOADVAL]], %[[CAST]]
+//      CHECK32:   %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i32 to i4
+//      CHECK32:   return %[[TRUNC]]
+
+// -----
+
+func.func @memref_store_i4(%arg0: index, %arg1: i4) -> () {
+    %0 = memref.alloc() : memref<5xi4>
+    memref.store %arg1, %0[%arg0] : memref<5xi4>
+    return
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * -4 + (s0 floordiv 2) * 8 + 4)>
+//      CHECK: func @memref_store_i4(
+// CHECK-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+//  CHECK-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+//  CHECK-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i8
+//  CHECK-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//  CHECK-DAG:   %[[BITOFFSET_I8:.+]] = arith.index_cast %[[BITOFFSET]] : index to i8
+//  CHECK-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i8
+//  CHECK-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], %[[BITOFFSET_I8]] : i8
+//  CHECK-DAG:   %[[CST_NEG_ONE:.+]] = arith.constant -1 : i8
+//  CHECK-DAG:   %[[MASK:.+]] = arith.xori %[[MASK_SHIFTED]], %[[CST_NEG_ONE]] : i8
+//  CHECK-DAG:   %[[SHIFTED_VAL:.+]] = arith.shli %[[EXTUI]], %[[BITOFFSET_I8]] : i8
+//      CHECK:   %[[CLEAR_RMW:.+]] = memref.atomic_rmw andi %[[MASK]], %alloc[%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+//      CHECK:   %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %alloc[%[[INDEX]]] : (i8, memref<3xi8>) -> i8
+//      CHECK:   return
+
+//  CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
+//  CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * -4 + (s0 floordiv 8) * 32 + 28)>
+//      CHECK32: func @memref_store_i4(
+// CHECK32-SAME:     %[[ARG0:.+]]: index, %[[ARG1:.+]]: i4
+//  CHECK32-DAG:   %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
+//  CHECK32-DAG:   %[[EXTUI:.+]] = arith.extui %[[ARG1]] : i4 to i32
+//  CHECK32-DAG:   %[[INDEX:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]]]
+//  CHECK32-DAG:   %[[BITOFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
+//  CHECK32-DAG:   %[[BITOFFSET_I32:.+]] = arith.index_cast %[[BITOFFSET]] : index to i32
+//  CHECK32-DAG:   %[[MASK_BASE:.+]] = arith.constant 15 : i32
+//  CHECK32-DAG:   %[[MASK_SHIFTED:.+]] = arith.shli %[[MASK_BASE]], ...
[truncated]


github-actions bot commented Nov 14, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Max191 force-pushed the subbyte-emulation-supports branch from 9901854 to bbc87f8 on November 14, 2023 00:54 with the commit message:
Adds narrow type emulation support for:
    - `memref.alloca`
    - `memref.store`
    - `memref.reinterpret_cast`

MaheshRavishankar (Contributor) commented

Could we split the store part out, since that part is a bit more involved?

hanhanW (Contributor) left a comment

Can we split this into three PRs?

  • memref.store
  • memref.reinterpret_cast
  • memref.alloca

struct ConvertMemRefAlloc final : OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
template <typename OpTy>
struct ConvertMemRefAlloc final : OpConversionPattern<OpTy> {

Perhaps we can rename it to ConvertMemRefAllocation. It is not only used by memref.alloc, but also memref.alloca.

//===----------------------------------------------------------------------===//

/// Currently there is very limited support for memref::ReinterpretCastOp
/// conversion. Only the 0 dimensional case is supported.

Why is only 0D supported? Are we able to emulate other cases?

Value maskRightAligned =
builder
.create<arith::ConstantOp>(loc, dstIntegerType, maskRightAlignedAttr)
.getResult();

IIRC, .getResult() is not needed when the Value type is specified. The same applies to the other code below.
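
For reference, a sketch of the suggested cleanup, assuming the usual MLIR behavior that a single-result op converts implicitly to Value; the helper name below is made up for illustration:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// arith::ConstantOp has a single result, so the op handle returned by
// builder.create<> converts implicitly to Value and the trailing
// .getResult() in the patch is redundant.
static Value buildRightAlignedMask(OpBuilder &builder, Location loc,
                                   IntegerType dstIntegerType,
                                   IntegerAttr maskRightAlignedAttr) {
  Value maskRightAligned = builder.create<arith::ConstantOp>(
      loc, dstIntegerType, maskRightAlignedAttr);
  return maskRightAligned;
}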

rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori,
alignedVal, adaptor.getMemref(),
ValueRange{});
rewriter.eraseOp(op);

Can we use replaceOp instead? That's more common in pattern-rewrite.
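
For comparison, the common idiom for an op that produces results is sketched below; replaceOp rewires all uses and erases the op in one step. The pattern name and context are made up for illustration and are not the store rewrite from this patch (memref.store has no results, which is why the patch erases it instead):

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

struct ReplaceLoadSketch final : OpConversionPattern<memref::LoadOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Build the replacement from the adapted (already converted) operands.
    Value newLoad = rewriter.create<memref::LoadOp>(
        op.getLoc(), adaptor.getMemref(), adaptor.getIndices());
    // Rewires every use of `op` and erases it; no separate eraseOp needed.
    rewriter.replaceOp(op, newLoad);
    return success();
  }
};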

Comment on lines +345 to +347
// Special case 0-rank memref stores. We can compute the mask at compile
// time.
if (convertedType.getRank() == 0) {

The code is becoming larger than I expected. Let's create two static functions: one for the 0-D case and the other for the non-0-D case. What do you think?

@@ -35,36 +36,98 @@ using namespace mlir;
/// Return the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals to 4 and `targetBits` equals to 8, the x-th element is
/// located at (x % 2) * 4. Because there are two elements in one i8, and one
/// element has 4 bits.
/// element has 4 bits. If `rightOffset` is true, return the offset from the
/// right side of the `dstBits` container instead of the left side.

This is confusing. I'd rather have two methods, getLeftOffset... and getRightOffset... (also, maybe it's worth finding something better than left and right).
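
To make the distinction concrete, here is a plain-integer restatement of the two affine expressions that the `rightOffset` flag selects between in `getOffsetForBitwidth`; the helper names are hypothetical, along the lines of the suggestion above:

static int getLeftAlignedBitOffset(int srcIdx, int srcBits, int dstBits) {
  int scale = dstBits / srcBits;      // sub-elements per dstBits container
  return (srcIdx % scale) * srcBits;  // the rightOffset == false expression
}

static int getRightAlignedBitOffset(int srcIdx, int srcBits, int dstBits) {
  int scale = dstBits / srcBits;
  return (scale - 1 - srcIdx % scale) * srcBits;  // rightOffset == true
}

For srcBits = 4, dstBits = 8, srcIdx = 1 these evaluate to 4 and 0, matching the #MAP1 affine maps in the load and store test cases above.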

Max191 (Contributor, Author) commented Nov 22, 2023

This PR has been split into:
#73144
#73174
#73138

Max191 closed this Nov 22, 2023
Successfully merging this pull request may close these issues.

Support sub-byte emulation for memref.store