[mlir][MemRef] Extend memref.subview sub-byte type emulation support. #94045


Merged: 3 commits into llvm:main on Jun 4, 2024

Conversation

hanhanW (Contributor) commented May 31, 2024

In some cases (see iree-org/iree#16285), `memref.subview` ops can't be folded into transfer ops, and sub-byte type emulation then fails. This has been blocking a few things, including the enablement of vector flattening transformations (iree-org/iree#16456). This PR extends the existing sub-byte type emulation support of `memref.subview` to handle multi-dimensional subviews with dynamic offsets, addressing some of the `memref.subview` cases that can't be folded.

Co-authored-by: Diego Caballero [email protected]
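
For concreteness, the IR below shows the shape of op this PR targets; it mirrors the @memref_subview_dynamic_offset_i4 test added in the diff. It is a multi-dimensional `memref.subview` with a dynamic offset that the folders can't eliminate, and which the old pattern rejected with "subview with rank > 1 is not supported":

func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
  %c0 = arith.constant 0 : index
  %arr = memref.alloc() : memref<512x64x8x16xi4>
  // Rank-4 subview with a dynamic offset in the outermost dimension.
  %subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 16] [1, 1, 1, 1]
      : memref<512x64x8x16xi4>
      to memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
  %ld = memref.load %subview[%c0, %c0, %c0, %c0]
      : memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
  return %ld : i4
}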

hanhanW added 2 commits May 31, 2024 13:27

llvmbot (Member) commented May 31, 2024

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)



Full diff: https://github.com/llvm/llvm-project/pull/94045.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+97-68)
  • (modified) mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp (+1-3)
  • (modified) mlir/test/Dialect/MemRef/emulate-narrow-type.mlir (+37-2)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 77c108aab4807..bfe97672aaf8b 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -32,62 +32,6 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
-/// type. The result MemRefType of the old op must have a rank and stride of 1,
-/// with static offset and size. The number of bits in the offset must evenly
-/// divide the bitwidth of the new converted type.
-template <typename MemRefOpTy>
-static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
-                                      typename MemRefOpTy::Adaptor adaptor,
-                                      MemRefOpTy op, MemRefType newTy) {
-  static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
-                    std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
-                "Expected only memref::SubViewOp or memref::ReinterpretCastOp");
-
-  auto convertedElementType = newTy.getElementType();
-  auto oldElementType = op.getType().getElementType();
-  int srcBits = oldElementType.getIntOrFloatBitWidth();
-  int dstBits = convertedElementType.getIntOrFloatBitWidth();
-  if (dstBits % srcBits != 0) {
-    return rewriter.notifyMatchFailure(op,
-                                       "only dstBits % srcBits == 0 supported");
-  }
-
-  // Only support stride of 1.
-  if (llvm::any_of(op.getStaticStrides(),
-                   [](int64_t stride) { return stride != 1; })) {
-    return rewriter.notifyMatchFailure(op->getLoc(),
-                                       "stride != 1 is not supported");
-  }
-
-  auto sizes = op.getStaticSizes();
-  int64_t offset = op.getStaticOffset(0);
-  // Only support static sizes and offsets.
-  if (llvm::any_of(sizes,
-                   [](int64_t size) { return size == ShapedType::kDynamic; }) ||
-      offset == ShapedType::kDynamic) {
-    return rewriter.notifyMatchFailure(
-        op->getLoc(), "dynamic size or offset is not supported");
-  }
-
-  int elementsPerByte = dstBits / srcBits;
-  if (offset % elementsPerByte != 0) {
-    return rewriter.notifyMatchFailure(
-        op->getLoc(), "offset not multiple of elementsPerByte is not "
-                      "supported");
-  }
-
-  SmallVector<int64_t> size;
-  if (sizes.size())
-    size.push_back(ceilDiv(sizes[0], elementsPerByte));
-  offset = offset / elementsPerByte;
-
-  rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
-                                          *adaptor.getODSOperands(0).begin(),
-                                          offset, size, op.getStaticStrides());
-  return success();
-}
-
 /// When data is loaded/stored in `targetBits` granularity, but is used in
 /// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
 /// treated as an array of elements of width `sourceBits`.
@@ -335,7 +279,48 @@ struct ConvertMemRefReinterpretCast final
           op->getLoc(), "subview with rank > 1 is not supported");
     }
 
-    return convertCastingOp(rewriter, adaptor, op, newTy);
+    Type convertedElementType = newTy.getElementType();
+    Type oldElementType = op.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    // Only support stride of 1.
+    if (llvm::any_of(op.getStaticStrides(),
+                     [](int64_t stride) { return stride != 1; })) {
+      return rewriter.notifyMatchFailure(op->getLoc(),
+                                         "stride != 1 is not supported");
+    }
+
+    auto sizes = op.getStaticSizes();
+    int64_t offset = op.getStaticOffset(0);
+    // Only support static sizes and offsets.
+    if (llvm::any_of(
+            sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+        offset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "dynamic size or offset is not supported");
+    }
+
+    int elementsPerByte = dstBits / srcBits;
+    if (offset % elementsPerByte != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "offset not multiple of elementsPerByte is not "
+                        "supported");
+    }
+
+    SmallVector<int64_t> size;
+    if (sizes.size())
+      size.push_back(ceilDiv(sizes[0], elementsPerByte));
+    offset = offset / elementsPerByte;
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
+        op.getStaticStrides());
+    return success();
   }
 };
 
@@ -402,29 +387,73 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
 
 /// Emulating narrow ints on subview have limited support, supporting only
 /// static offset and size and stride of 1. Ideally, the subview should be
-/// folded away before running narrow type emulation, and this pattern would
-/// never run. This pattern is mostly used for testing pruposes.
+/// folded away before running narrow type emulation, and this pattern should
+/// only run for cases that can't be folded.
 struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
+  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MemRefType newTy =
-        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    MemRefType newTy = dyn_cast<MemRefType>(
+        getTypeConverter()->convertType(subViewOp.getType()));
     if (!newTy) {
       return rewriter.notifyMatchFailure(
-          op->getLoc(),
-          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+          subViewOp->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}",
+                        subViewOp.getType()));
+    }
+
+    Location loc = subViewOp.getLoc();
+    Type convertedElementType = newTy.getElementType();
+    Type oldElementType = subViewOp.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0)
+      return rewriter.notifyMatchFailure(
+          subViewOp, "only dstBits % srcBits == 0 supported");
+
+    // Only support stride of 1.
+    if (llvm::any_of(subViewOp.getStaticStrides(),
+                     [](int64_t stride) { return stride != 1; })) {
+      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
+                                         "stride != 1 is not supported");
     }
 
-    // Only support offset for 1-D subview.
-    if (op.getType().getRank() != 1) {
+    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
       return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with rank > 1 is not supported");
+          subViewOp, "the result memref type is not contiguous");
     }
 
-    return convertCastingOp(rewriter, adaptor, op, newTy);
+    auto sizes = subViewOp.getStaticSizes();
+    int64_t lastOffset = subViewOp.getStaticOffsets().back();
+    // Only support static sizes and offsets.
+    if (llvm::any_of(
+            sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+        lastOffset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          subViewOp->getLoc(), "dynamic size or offset is not supported");
+    }
+
+    // Transform the offsets, sizes and strides according to the emulation.
+    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+        loc, subViewOp.getViewSource());
+
+    OpFoldResult linearizedIndices;
+    auto strides = stridedMetadata.getConstifiedMixedStrides();
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            subViewOp.getMixedSizes(), strides,
+            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
+                           rewriter));
+
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
+        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
+        linearizedInfo.linearizedSize, strides.back());
+    return success();
   }
 };
 
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 05d5ca2ce12f4..68edd45448ee5 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -68,7 +68,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
   AffineExpr mulMap = builder.getAffineConstantExpr(1);
 
   SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
-  SmallVector<OpFoldResult> sizeValues(sourceRank);
 
   for (unsigned i = 0; i < sourceRank; ++i) {
     unsigned offsetIdx = 2 * i;
@@ -79,8 +78,7 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
     mulMap = mulMap * symbols[i];
   }
 
-  // Adjust linearizedIndices, size and offset by the scale factor (dstBits /
-  // srcBits).
+  // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
   int64_t scaler = dstBits / srcBits;
   addMulMap = addMulMap.floorDiv(scaler);
   mulMap = mulMap.floorDiv(scaler);
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index 435dcc944778d..a67237b5e4dd1 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
-// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --verify-diagnostics --split-input-file %s | FileCheck %s
+// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --verify-diagnostics --split-input-file %s | FileCheck %s --check-prefix=CHECK32
 
 // Expect no conversions.
 func.func @memref_i8() -> i8 {
@@ -177,6 +177,41 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
 
 // -----
 
+func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
+  %c0 = arith.constant 0 : index
+  %arr = memref.alloc() : memref<512x64x8x16xi4>
+  %subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 16] [1, 1, 1, 1] : memref<512x64x8x16xi4>
+                                                                            to memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
+  %ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
+  return %ld : i4
+}
+
+// CHECK-LABEL:   func.func @memref_subview_dynamic_offset_i4(
+// CHECK:           %[[ALLOC:.*]] = memref.alloc() : memref<2097152xi8>
+// CHECK:           %[[IDX:.*]] = affine.apply
+// CHECK:           %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [65536] [1] : memref<2097152xi8> to memref<65536xi8, strided<[1], offset: ?>>
+// CHECK:           memref.load %[[SUBVIEW]]
+
+// CHECK32-LABEL:   func.func @memref_subview_dynamic_offset_i4(
+// CHECK32:           %[[ALLOC:.*]] = memref.alloc() : memref<524288xi32>
+// CHECK32:           %[[IDX:.*]] = affine.apply
+// CHECK32:           %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [16384] [1] : memref<524288xi32> to memref<16384xi32, strided<[1], offset: ?>>
+// CHECK32:           memref.load %[[SUBVIEW]]
+
+// -----
+
+
+func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
+  %c0 = arith.constant 0 : index
+  %arr = memref.alloc() : memref<40x40xi4>
+  // expected-error @+1 {{failed to legalize operation 'memref.subview' that was explicitly marked illegal}}
+  %subview = memref.subview %arr[%idx, 0] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>
+  %ld = memref.load %subview[%c0, %c0] : memref<4x8xi4, strided<[40, 1], offset:?>>
+  return %ld : i4
+}
+
+// -----
+
 func.func @reinterpret_cast_memref_load_0D() -> i4 {
     %0 = memref.alloc() : memref<5xi4>
     %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
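
To make the rewrite concrete, here is the arithmetic behind the new @memref_subview_dynamic_offset_i4 test. The allocation memref<512x64x8x16xi4> holds 4194304 i4 elements, stored as memref<2097152xi8> under 8-bit emulation (two i4 per byte). The subview's row-major strides are [8192, 128, 16, 1], so the offset [%idx, 0, 0, 0] linearizes to %idx * 8192 i4 elements, i.e. a byte offset of (%idx * 8192) floordiv 2, and the 16x64x8x16 result (131072 i4 elements) becomes a 65536-byte 1-D view. A hedged sketch of the emulated IR follows; the concrete affine_map is an assumption, since the CHECK lines only match an affine.apply:

// Sketch of the emulated form under memref-load-bitwidth=8.
%alloc = memref.alloc() : memref<2097152xi8>
// Linearized offset in bytes: (%idx * 8192 i4 elements) / (2 i4 per byte).
%off = affine.apply affine_map<()[s0] -> ((s0 * 8192) floordiv 2)>()[%idx]
%subview = memref.subview %alloc[%off] [65536] [1]
    : memref<2097152xi8> to memref<65536xi8, strided<[1], offset: ?>>

Under memref-load-bitwidth=32 the same logic divides by 8 instead, giving the memref<524288xi32> allocation and [16384] size in the CHECK32 lines.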

hanhanW (Contributor, Author) commented May 31, 2024

The first commit is cherry-picked from #89488, and the second commit addresses the issues raised in #89488 (comment).

MaheshRavishankar (Contributor) left a comment:

Mostly looks ok to me. I can stamp after the inlining is removed so that I can verify the changes better.

@@ -335,7 +279,48 @@ struct ConvertMemRefReinterpretCast final
op->getLoc(), "subview with rank > 1 is not supported");
}

return convertCastingOp(rewriter, adaptor, op, newTy);
Type convertedElementType = newTy.getElementType();
MaheshRavishankar (Contributor): Could you avoid this inlining? It will help me understand the change better.

hanhanW (Contributor, Author): Sure, done. Now it is only used by the ReinterpretCastOp pattern, so I removed the template.

hanhanW requested a review from MaheshRavishankar June 3, 2024 21:54
hanhanW merged commit e3c9c82 into llvm:main Jun 4, 2024
7 checks passed
hanhanW deleted the subview-emulation-fork-from-89488 branch June 4, 2024 05:02