
[mlir][MemRef] Extend memref.subview sub-byte type emulation support #89488


Closed
wants to merge 2 commits

Conversation

dcaballe (Contributor)

In some cases (see iree-org/iree#16285), `memref.subview` ops can't be folded into transfer ops and sub-byte type emulation fails. This issue has been blocking a few things, including the enablement of vector flattening transformations (iree-org/iree#16456). This PR extends the existing sub-byte type emulation support of `memref.subview` to handle multi-dimensional subviews with dynamic offsets and addresses the issues for some of the `memref.subview` cases that can't be folded.

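For context, the pattern touched here runs as part of the narrow-type emulation conversion. Below is a minimal driver sketch, assuming the upstream entry points from mlir/Dialect/MemRef/Transforms/Transforms.h and mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h; the surrounding pass boilerplate is illustrative, not part of this PR:

#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Sketch: rewrite sub-byte (e.g. i4) memref code to operate on a wider
// storage type (here 8 bits). Only the populate* calls are the upstream
// API; the driver setup is illustrative.
static LogicalResult emulateNarrowTypes(Operation *rootOp) {
  MLIRContext *ctx = rootOp->getContext();
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

  // Ops whose operand/result types already match the converted types are
  // legal; everything else gets rewritten by the emulation patterns.
  ConversionTarget target(*ctx);
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *op) { return typeConverter.isLegal(op); });

  RewritePatternSet patterns(ctx);
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
  return applyPartialConversion(rootOp, target, std::move(patterns));
}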
llvmbot (Member) commented Apr 20, 2024

@llvm/pr-subscribers-mlir-memref
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-vector

Author: Diego Caballero (dcaballe)


Full diff: https://github.com/llvm/llvm-project/pull/89488.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.h (+4)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp (+92-70)
  • (modified) mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp (+1-3)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+7)
  • (modified) mlir/test/Dialect/MemRef/emulate-narrow-type.mlir (+23)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index 4603953cb40fa5..ddeda1fe1c692c 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -145,6 +145,10 @@ inline bool isReductionIterator(Attribute attr) {
 /// constant operations.
 SmallVector<int64_t> getAsIntegers(ArrayRef<Value> values);
 
+/// Returns the integer number in `foldResult`. `foldResult` is expected to
+/// be a constant operation.
+int64_t getAsInteger(OpFoldResult foldResult);
+
 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
 /// be constant operations.
 SmallVector<int64_t> getAsIntegers(ArrayRef<OpFoldResult> foldResults);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 8236a4c475f17c..8325da357f34c6 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -24,7 +23,6 @@
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
 #include <cassert>
 #include <type_traits>
 
@@ -34,62 +32,6 @@ using namespace mlir;
 // Utility functions
 //===----------------------------------------------------------------------===//
 
-/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
-/// type. The result MemRefType of the old op must have a rank and stride of 1,
-/// with static offset and size. The number of bits in the offset must evenly
-/// divide the bitwidth of the new converted type.
-template <typename MemRefOpTy>
-static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
-                                      typename MemRefOpTy::Adaptor adaptor,
-                                      MemRefOpTy op, MemRefType newTy) {
-  static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
-                    std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
-                "Expected only memref::SubViewOp or memref::ReinterpretCastOp");
-
-  auto convertedElementType = newTy.getElementType();
-  auto oldElementType = op.getType().getElementType();
-  int srcBits = oldElementType.getIntOrFloatBitWidth();
-  int dstBits = convertedElementType.getIntOrFloatBitWidth();
-  if (dstBits % srcBits != 0) {
-    return rewriter.notifyMatchFailure(op,
-                                       "only dstBits % srcBits == 0 supported");
-  }
-
-  // Only support stride of 1.
-  if (llvm::any_of(op.getStaticStrides(),
-                   [](int64_t stride) { return stride != 1; })) {
-    return rewriter.notifyMatchFailure(op->getLoc(),
-                                       "stride != 1 is not supported");
-  }
-
-  auto sizes = op.getStaticSizes();
-  int64_t offset = op.getStaticOffset(0);
-  // Only support static sizes and offsets.
-  if (llvm::any_of(sizes,
-                   [](int64_t size) { return size == ShapedType::kDynamic; }) ||
-      offset == ShapedType::kDynamic) {
-    return rewriter.notifyMatchFailure(
-        op->getLoc(), "dynamic size or offset is not supported");
-  }
-
-  int elementsPerByte = dstBits / srcBits;
-  if (offset % elementsPerByte != 0) {
-    return rewriter.notifyMatchFailure(
-        op->getLoc(), "offset not multiple of elementsPerByte is not "
-                      "supported");
-  }
-
-  SmallVector<int64_t> size;
-  if (sizes.size())
-    size.push_back(ceilDiv(sizes[0], elementsPerByte));
-  offset = offset / elementsPerByte;
-
-  rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
-                                          *adaptor.getODSOperands(0).begin(),
-                                          offset, size, op.getStaticStrides());
-  return success();
-}
-
 /// When data is loaded/stored in `targetBits` granularity, but is used in
 /// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
 /// treated as an array of elements of width `sourceBits`.
@@ -337,7 +279,48 @@ struct ConvertMemRefReinterpretCast final
           op->getLoc(), "subview with rank > 1 is not supported");
     }
 
-    return convertCastingOp(rewriter, adaptor, op, newTy);
+    Type convertedElementType = newTy.getElementType();
+    Type oldElementType = op.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+
+    // Only support stride of 1.
+    if (llvm::any_of(op.getStaticStrides(),
+                     [](int64_t stride) { return stride != 1; })) {
+      return rewriter.notifyMatchFailure(op->getLoc(),
+                                         "stride != 1 is not supported");
+    }
+
+    auto sizes = op.getStaticSizes();
+    int64_t offset = op.getStaticOffset(0);
+    // Only support static sizes and offsets.
+    if (llvm::any_of(
+            sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+        offset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "dynamic size or offset is not supported");
+    }
+
+    int elementsPerByte = dstBits / srcBits;
+    if (offset % elementsPerByte != 0) {
+      return rewriter.notifyMatchFailure(
+          op->getLoc(), "offset not multiple of elementsPerByte is not "
+                        "supported");
+    }
+
+    SmallVector<int64_t> size;
+    if (sizes.size())
+      size.push_back(ceilDiv(sizes[0], elementsPerByte));
+    offset = offset / elementsPerByte;
+
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
+        op.getStaticStrides());
+    return success();
   }
 };
 
@@ -404,29 +387,68 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {
 
 /// Emulating narrow ints on subview have limited support, supporting only
 /// static offset and size and stride of 1. Ideally, the subview should be
-/// folded away before running narrow type emulation, and this pattern would
-/// never run. This pattern is mostly used for testing pruposes.
+/// folded away before running narrow type emulation, and this pattern should
+/// only run for cases that can't be folded.
 struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
   using OpConversionPattern::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
+  matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MemRefType newTy =
-        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
+    MemRefType newTy = dyn_cast<MemRefType>(
+        getTypeConverter()->convertType(subViewOp.getType()));
     if (!newTy) {
       return rewriter.notifyMatchFailure(
-          op->getLoc(),
-          llvm::formatv("failed to convert memref type: {0}", op.getType()));
+          subViewOp->getLoc(),
+          llvm::formatv("failed to convert memref type: {0}",
+                        subViewOp.getType()));
     }
 
-    // Only support offset for 1-D subview.
-    if (op.getType().getRank() != 1) {
+    Location loc = subViewOp.getLoc();
+    Type convertedElementType = newTy.getElementType();
+    Type oldElementType = subViewOp.getType().getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = convertedElementType.getIntOrFloatBitWidth();
+    if (dstBits % srcBits != 0)
       return rewriter.notifyMatchFailure(
-          op->getLoc(), "subview with rank > 1 is not supported");
+          subViewOp, "only dstBits % srcBits == 0 supported");
+
+    // Only support stride of 1.
+    if (llvm::any_of(subViewOp.getStaticStrides(),
+                     [](int64_t stride) { return stride != 1; })) {
+      return rewriter.notifyMatchFailure(subViewOp->getLoc(),
+                                         "stride != 1 is not supported");
+    }
+
+    auto sizes = subViewOp.getStaticSizes();
+    int64_t lastOffset = subViewOp.getStaticOffsets().back();
+    // Only support static sizes and offsets.
+    if (llvm::any_of(
+            sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
+        lastOffset == ShapedType::kDynamic) {
+      return rewriter.notifyMatchFailure(
+          subViewOp->getLoc(), "dynamic size or offset is not supported");
     }
 
-    return convertCastingOp(rewriter, adaptor, op, newTy);
+    // Transform the offsets, sizes and strides according to the emulation.
+    auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+        loc, subViewOp.getViewSource());
+
+    OpFoldResult linearizedIndices;
+    auto strides = stridedMetadata.getConstifiedMixedStrides();
+    memref::LinearizedMemRefInfo linearizedInfo;
+    std::tie(linearizedInfo, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            subViewOp.getMixedSizes(), strides,
+            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
+                           rewriter));
+
+    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
+        subViewOp, newTy, adaptor.getSource(), linearizedIndices,
+        linearizedInfo.linearizedSize, strides.back());
+    return success();
   }
 };
 
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 556a82de2166f7..588612b55fbd41 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -67,7 +67,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
   AffineExpr mulMap = builder.getAffineConstantExpr(1);
 
   SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
-  SmallVector<OpFoldResult> sizeValues(sourceRank);
 
   for (unsigned i = 0; i < sourceRank; ++i) {
     unsigned offsetIdx = 2 * i;
@@ -78,8 +77,7 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
     mulMap = mulMap * symbols[i];
   }
 
-  // Adjust linearizedIndices, size and offset by the scale factor (dstBits /
-  // srcBits).
+  // Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
   int64_t scaler = dstBits / srcBits;
   addMulMap = addMulMap.floorDiv(scaler);
   mulMap = mulMap.floorDiv(scaler);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 3e6425879cc67f..4f62dffa48a935 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -291,6 +291,13 @@ SmallVector<int64_t> vector::getAsIntegers(ArrayRef<Value> values) {
   return ints;
 }
 
+/// Returns the integer number in `foldResult`. `foldResult` is expected to
+/// be a constant operation.
+int64_t vector::getAsInteger(OpFoldResult foldResult) {
+  assert(foldResult.is<Attribute>() && "Unexpected non-constant index");
+  return cast<IntegerAttr>(foldResult.get<Attribute>()).getInt();
+}
+
 /// Returns the integer numbers in `foldResults`. `foldResults` are expected to
 /// be constant operations.
 SmallVector<int64_t> vector::getAsIntegers(ArrayRef<OpFoldResult> foldResults) {
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index fd37b7ff0a2713..7e4002b8fff549 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -177,6 +177,29 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
 
 // -----
 
+func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
+  %c0 = arith.constant 0 : index
+  %arr = memref.alloc() : memref<512x64x8x16xi4>
+  %subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 16] [1, 1, 1, 1] : memref<512x64x8x16xi4>
+                                                                            to memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
+  %ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
+  return %ld : i4
+}
+
+// CHECK-LABEL:   func.func @memref_subview_dynamic_offset_i4(
+// CHECK:           %[[ALLOC:.*]] = memref.alloc() : memref<2097152xi8>
+// CHECK:           %[[IDX:.*]] = affine.apply
+// CHECK:           %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [65536] [1] : memref<2097152xi8> to memref<65536xi8, strided<[1], offset: ?>>
+// CHECK:           memref.load %[[SUBVIEW]]
+
+// CHECK32-LABEL:   func.func @memref_subview_dynamic_offset_i4(
+// CHECK32:           %[[ALLOC:.*]] = memref.alloc() : memref<524288xi32>
+// CHECK32:           %[[IDX:.*]] = affine.apply
+// CHECK32:           %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [16384] [1] : memref<524288xi32> to memref<16384xi32, strided<[1], offset: ?>>
+// CHECK32:           memref.load %[[SUBVIEW]]
+
+// -----
+
 func.func @reinterpret_cast_memref_load_0D() -> i4 {
     %0 = memref.alloc() : memref<5xi4>
     %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
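For reference, the CHECK constants in the new test follow directly from the i4 packing: the source memref<512x64x8x16xi4> holds 512 × 64 × 8 × 16 = 4,194,304 i4 elements, i.e. 2,097,152 i8 bytes (2 i4 per byte) or 524,288 i32 words (8 i4 per word). The subview spans 16 × 64 × 8 × 16 = 131,072 i4 elements, giving the linearized sizes 65,536 (i8) and 16,384 (i32), and the dynamic offset of %idx × 8192 i4 elements is divided by the same packing factors (2 and 8, respectively) in the affine.apply.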

            getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
                           rewriter));

    rewriter.replaceOpWithNewOp<memref::SubViewOp>(
Contributor
I think this is wrong. You can only linearize a subview that is "contiguous". So you have to check that the subview is contiguous in memory.

Contributor Author
Thanks! Lines 417-421 check that the subview has only unit strides. Is there anything else needed? I'm using `strides.back()` (L450) because I know that all the strides are one, and therefore the stride of the new subview would also be one.

Contributor
I don't think just checking the strides of the subview is enough... you also have to check the strides of the result memref type. Those strides need to be contiguous.

Contributor Author
Let me add a test to better understand what happens...

Contributor Author
Hey, I tried with a few tests like these:

func.func @memref_subview_dynamic_offset_i4_1(%idx : index) -> i4 {
  %c0 = arith.constant 0 : index
  %arr = memref.alloc() : memref<512x64x8x16xi4>
  %subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 4] [1, 1, 1, 4] : memref<512x64x8x16xi4>
                                                                            to memref<16x64x8x4xi4, strided<[8192, 128, 16, 4], offset: ?>>
  %ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x64x8x4xi4, strided<[8192, 128, 16, 4], offset: ?>>
  return %ld : i4
}

func.func @memref_subview_dynamic_offset_i4_2(%idx : index) -> i4 {
  %c0 = arith.constant 0 : index
  %arr = memref.alloc() : memref<512x64x8x16xi4>
  %subview = memref.subview %arr[%idx, 0, 0, 0] [16, 4, 8, 16] [1, 16, 1, 1] : memref<512x64x8x16xi4>
                                                                            to memref<16x4x8x16xi4, strided<[8192, 2048, 16, 1], offset: ?>>
  %ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x4x8x16xi4, strided<[8192, 2048, 16, 1], offset: ?>>
  return %ld : i4
}

and all of them seem to be covered by the existing checks. Can you think of any other example?

> you have to check the strides of the result memref type. Those strides need to be contiguous.

Note that the type of the new subview, `newTy`, comes from the emulation converter (L398), where we check the strides of the original memref and then linearize the shape, which also results in a memref with a single unit stride. Not sure what else I can check.
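(Concretely, in the new test the converter turns memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>> into memref<65536xi8, strided<[1], offset: ?>>.)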

Contributor
%subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 4] [1, 1, 1, 4] : memref<512x64x8x16xi4>
                                                                          to memref<16x64x8x4xi4, strided<[8192, 128, 16, 4], offset: ?>>

The result subview here is contiguous... If you already checked for it (or if it is somehow already enforced), I might have missed it... but take a subview of this form

%subview = memref.subview %arr[%idx0, %idx] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>

Here the result type is not contiguous. You cannot represent this as a linearized type. Is this handled?
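To make that concrete, a worked check on the example above: element (i, j) of the subview sits at linear position offset + 40·i + j, so the view touches four runs of 8 contiguous i4 elements whose starts are 40 apart, leaving 32-element gaps between rows. No single (offset, size, unit-stride) range covers exactly those elements, which is why a linearized rewrite can't represent this subview.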

Contributor
I think we can check it using `isStaticShapeAndContiguousRowMajor` in this case. It ensures that the memref type is contiguous.
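A minimal sketch of that guard, assuming `memref::isStaticShapeAndContiguousRowMajor` from mlir/Dialect/MemRef/Utils/MemRefUtils.h; its exact placement inside ConvertMemRefSubview::matchAndRewrite is illustrative:

// Bail out when the subview's result layout is not contiguous row-major:
// such views cannot be expressed as a single linearized range.
if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
  return rewriter.notifyMatchFailure(
      subViewOp, "result memref type is not contiguous");
}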

hanhanW (Contributor) commented Jun 4, 2024

As discussed with Diego offline, I addressed the comments and landed the commit in #94045, closing this PR.

hanhanW closed this on Jun 4, 2024