[mlir][vector] Move hidden function to op definition #140813

Closed
wants to merge 2 commits into from

Conversation

newling
Contributor

@newling newling commented May 20, 2025

This non-functional change moves the function getStridedSliceInsertionIndices, introduced in #138725, out of VectorLinearize.cpp and exposes it through a getLinearIndices method on the TableGen class definitions of vector.insert_strided_slice and vector.extract_strided_slice. Alternatives considered:

  1. duplicate it in the .cpp file where I currently need it
  2. make it a free function exposed in VectorUtils.h

It's quite a large function, so (1) doesn't seem good. My concern with (2) is that it'll get 'lost'; having it on the op class makes it more likely to be reused, IMO.

Question: is the testing sufficient? It's indirectly lit-tested by #138725.
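
On the testing question, the enumeration itself is easy to check in isolation. Below is a minimal standalone sketch of the same row-major enumeration, assuming std::vector in place of ArrayRef/SmallVector. The name sliceIndicesSketch is hypothetical and the code is not part of this patch; it only mirrors the documented behaviour of getStridedSliceInsertionIndices.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical standalone mirror of getStridedSliceInsertionIndices.
// Returns the row-major linear indices in a vector of shape `large` that are
// covered by a slice of shape `small` placed at position `offsets`.
std::vector<int64_t> sliceIndicesSketch(const std::vector<int64_t> &small,
                                        const std::vector<int64_t> &large,
                                        const std::vector<int64_t> &offsets) {
  assert(large.size() >= small.size() && "rank of large below rank of small");
  assert(large.size() >= offsets.size() && "more offsets than dims in large");
  const int64_t rank = static_cast<int64_t>(large.size());
  const int64_t delta = rank - static_cast<int64_t>(small.size());
  const int64_t nOffsets = static_cast<int64_t>(offsets.size());

  // `small` has implicit leading 1s, `offsets` has implicit trailing 0s.
  auto getSmall = [&](int64_t i) -> int64_t {
    return i >= delta ? small[i - delta] : 1;
  };
  auto getOffset = [&](int64_t i) -> int64_t {
    return i < nOffsets ? offsets[i] : 0;
  };

  // Walk from the innermost dimension outwards. Each step replicates the
  // indices found so far once per element of `small` in that dimension.
  std::vector<int64_t> indices{0};
  int64_t stride = 1;
  for (int64_t i = rank - 1; i >= 0; --i) {
    std::vector<int64_t> next;
    next.reserve(indices.size() * getSmall(i));
    int64_t offset = getOffset(i) * stride;
    for (int64_t j = 0; j < getSmall(i); ++j) {
      for (int64_t index : indices)
        next.push_back(index + offset);
      offset += stride;
    }
    stride *= large[i];
    indices = std::move(next);
  }
  return indices;
}
```

With the example from the doc comment, sliceIndicesSketch({1, 2}, {4, 5}, {1, 3}) yields [8, 9]. A gtest-style unit test on the op method could assert the same values directly rather than relying on the linearization lit tests.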

@newling newling force-pushed the move_function_to_op_def branch from 5bd81d0 to 1cc6345 Compare May 22, 2025 19:05
@llvmbot
Member

llvmbot commented May 22, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: James Newling (newling)

Full diff: https://github.com/llvm/llvm-project/pull/140813.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+4)
  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+114)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+14-112)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 3f5564541554e..481523ff10c3f 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -1114,6 +1114,8 @@ def Vector_InsertStridedSliceOp :
         return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
       });
     }
+    /// \return The linearized indices in `dest` where values are stored.
+    FailureOr<SmallVector<int64_t>> getLinearIndices();
   }];
 
   let hasFolder = 1;
@@ -1254,6 +1256,8 @@ def Vector_ExtractStridedSliceOp :
         return ::llvm::cast<IntegerAttr>(attr).getInt() != 1;
       });
     }
+    /// \return The linearized indices in `source` where values are extracted.
+    FailureOr<SmallVector<int64_t>> getLinearIndices();
   }];
   let hasCanonicalizer = 1;
   let hasFolder = 1;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 41777347975da..e800b7b7c9ff6 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3182,6 +3182,101 @@ void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                       stridesAttr);
 }
 
+/// Convert an array of attributes into a vector of integers, if possible.
+static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
+  if (!attrs)
+    return failure();
+  SmallVector<int64_t> ints;
+  ints.reserve(attrs.size());
+  for (auto attr : attrs) {
+    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+      ints.push_back(intAttr.getInt());
+    } else {
+      return failure();
+    }
+  }
+  return ints;
+}
+
+/// Consider inserting a vector of shape `small` into a vector of shape `large`,
+/// at position `offsets`: this function enumerates all the indices in `large`
+/// that are written to. The enumeration is with row-major ordering.
+///
+/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
+/// positions written to are (1,3) and (1,4), which have linearized indices 8
+/// and 9. So [8,9] is returned.
+///
+/// The length of the returned vector is equal to the number of elements in
+/// the shape `small` (i.e. the product of dimensions of `small`).
+static SmallVector<int64_t>
+getStridedSliceInsertionIndices(ArrayRef<int64_t> small,
+                                ArrayRef<int64_t> large,
+                                ArrayRef<int64_t> offsets) {
+
+  // Example of alignment between `large`, `small` and `offsets`:
+  //    large  =  4, 5, 6, 7, 8
+  //    small  =     1, 6, 7, 8
+  //  offsets  =  2, 3, 0
+  //
+  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
+  assert((large.size() >= small.size()) &&
+         "rank of 'large' cannot be lower than rank of 'small'");
+  assert((large.size() >= offsets.size()) &&
+         "rank of 'large' cannot be lower than the number of offsets");
+  unsigned delta = large.size() - small.size();
+  unsigned nOffsets = offsets.size();
+  auto getSmall = [&](int64_t i) -> int64_t {
+    return i >= delta ? small[i - delta] : 1;
+  };
+  auto getOffset = [&](int64_t i) -> int64_t {
+    return i < nOffsets ? offsets[i] : 0;
+  };
+
+  // Using 2 vectors of indices, at each iteration populate the updated set of
+  // indices based on the old set of indices, and the size of the small vector
+  // in the current iteration.
+  SmallVector<int64_t> indices{0};
+  int64_t stride = 1;
+  for (int i = large.size() - 1; i >= 0; --i) {
+    int64_t currentSize = indices.size();
+    int64_t smallSize = getSmall(i);
+    int64_t nextSize = currentSize * smallSize;
+    SmallVector<int64_t> nextIndices(nextSize);
+    int64_t *base = nextIndices.begin();
+    int64_t offset = getOffset(i) * stride;
+    for (int j = 0; j < smallSize; ++j) {
+      for (int k = 0; k < currentSize; ++k) {
+        base[k] = indices[k] + offset;
+      }
+      offset += stride;
+      base += currentSize;
+    }
+    stride *= large[i];
+    indices = std::move(nextIndices);
+  }
+  return indices;
+}
+
+FailureOr<SmallVector<int64_t>> InsertStridedSliceOp::getLinearIndices() {
+
+  // Stride > 1 to be considered if/when the insert_strided_slice supports it.
+  if (hasNonUnitStrides())
+    return failure();
+
+  // Only when the destination has a static size can the indices be enumerated.
+  if (getType().isScalable())
+    return failure();
+
+  // Only when the offsets are all static can the indices be enumerated.
+  FailureOr<SmallVector<int64_t>> offsets = intsFromArrayAttr(getOffsets());
+  if (failed(offsets))
+    return failure();
+
+  return getStridedSliceInsertionIndices(getSourceVectorType().getShape(),
+                                         getDestVectorType().getShape(),
+                                         offsets.value());
+}
+
 // TODO: Should be moved to Tablegen ConfinedAttr attributes.
 template <typename OpType>
 static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
@@ -3638,6 +3733,25 @@ void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
                       stridesAttr);
 }
 
+FailureOr<SmallVector<int64_t>> ExtractStridedSliceOp::getLinearIndices() {
+
+  // Stride > 1 to be considered if/when extract_strided_slice supports it.
+  if (hasNonUnitStrides())
+    return failure();
+
+  // Only when the source has a static size can the indices be enumerated.
+  if (getSourceVectorType().isScalable())
+    return failure();
+
+  // Only when the offsets are all static can the indices be enumerated.
+  FailureOr<SmallVector<int64_t>> offsets = intsFromArrayAttr(getOffsets());
+  if (failed(offsets))
+    return failure();
+
+  return getStridedSliceInsertionIndices(
+      getType().getShape(), getSourceVectorType().getShape(), offsets.value());
+}
+
 LogicalResult ExtractStridedSliceOp::verify() {
   auto type = getSourceVectorType();
   auto offsets = getOffsetsAttr();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..6cf818bbd0695 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -109,90 +109,6 @@ struct LinearizeVectorizable final
   }
 };
 
-template <typename TOp>
-static bool stridesAllOne(TOp op) {
-  static_assert(
-      std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
-          std::is_same_v<TOp, vector::InsertStridedSliceOp>,
-      "expected vector.extract_strided_slice or vector.insert_strided_slice");
-  ArrayAttr strides = op.getStrides();
-  return llvm::all_of(strides, isOneInteger);
-}
-
-/// Convert an array of attributes into a vector of integers, if possible.
-static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
-  if (!attrs)
-    return failure();
-  SmallVector<int64_t> ints;
-  ints.reserve(attrs.size());
-  for (auto attr : attrs) {
-    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
-      ints.push_back(intAttr.getInt());
-    } else {
-      return failure();
-    }
-  }
-  return ints;
-}
-
-/// Consider inserting a vector of shape `small` into a vector of shape `large`,
-/// at position `offsets`: this function enumeratates all the indices in `large`
-/// that are written to. The enumeration is with row-major ordering.
-///
-/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
-/// positions written to are (1,3) and (1,4), which have linearized indices 8
-/// and 9. So [8,9] is returned.
-///
-/// The length of the returned vector is equal to the number of elements in
-/// the shape `small` (i.e. the product of dimensions of `small`).
-SmallVector<int64_t> static getStridedSliceInsertionIndices(
-    ArrayRef<int64_t> small, ArrayRef<int64_t> large,
-    ArrayRef<int64_t> offsets) {
-
-  // Example of alignment between, `large`, `small` and `offsets`:
-  //    large  =  4, 5, 6, 7, 8
-  //    small  =     1, 6, 7, 8
-  //  offsets  =  2, 3, 0
-  //
-  // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
-  assert((large.size() >= small.size()) &&
-         "rank of 'large' cannot be lower than rank of 'small'");
-  assert((large.size() >= offsets.size()) &&
-         "rank of 'large' cannot be lower than the number of offsets");
-  unsigned delta = large.size() - small.size();
-  unsigned nOffsets = offsets.size();
-  auto getSmall = [&](int64_t i) -> int64_t {
-    return i >= delta ? small[i - delta] : 1;
-  };
-  auto getOffset = [&](int64_t i) -> int64_t {
-    return i < nOffsets ? offsets[i] : 0;
-  };
-
-  // Using 2 vectors of indices, at each iteration populate the updated set of
-  // indices based on the old set of indices, and the size of the small vector
-  // in the current iteration.
-  SmallVector<int64_t> indices{0};
-  int64_t stride = 1;
-  for (int i = large.size() - 1; i >= 0; --i) {
-    int64_t currentSize = indices.size();
-    int64_t smallSize = getSmall(i);
-    int64_t nextSize = currentSize * smallSize;
-    SmallVector<int64_t> nextIndices(nextSize);
-    int64_t *base = nextIndices.begin();
-    int64_t offset = getOffset(i) * stride;
-    for (int j = 0; j < smallSize; ++j) {
-      for (int k = 0; k < currentSize; ++k) {
-        base[k] = indices[k] + offset;
-      }
-      offset += stride;
-      base += currentSize;
-    }
-    stride *= large[i];
-    indices = std::move(nextIndices);
-  }
-  return indices;
-}
-
 /// This pattern converts a vector.extract_strided_slice operation into a
 /// vector.shuffle operation that has a rank-1 (linearized) operand and result.
 ///
@@ -231,30 +147,23 @@ struct LinearizeVectorExtractStridedSlice final
 
     // Expect a legalization failure if the strides are not all 1 (if ever the
     // verifier for extract_strided_slice allows non-1 strides).
-    if (!stridesAllOne(extractStridedSliceOp)) {
+    if (extractStridedSliceOp.hasNonUnitStrides()) {
       return rewriter.notifyMatchFailure(
           extractStridedSliceOp,
           "extract_strided_slice with strides != 1 not supported");
     }
 
-    FailureOr<SmallVector<int64_t>> offsets =
-        intsFromArrayAttr(extractStridedSliceOp.getOffsets());
-    if (failed(offsets)) {
+    FailureOr<SmallVector<int64_t>> indices =
+        extractStridedSliceOp.getLinearIndices();
+    if (failed(indices)) {
       return rewriter.notifyMatchFailure(extractStridedSliceOp,
-                                         "failed to get integer offsets");
+                                         "failed to get indices");
     }
 
-    ArrayRef<int64_t> inputShape =
-        extractStridedSliceOp.getSourceVectorType().getShape();
-
-    ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
-
-    SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
-        outputShape, inputShape, offsets.value());
-
     Value srcVector = adaptor.getVector();
-    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
-        extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
+    rewriter.replaceOpWithNewOp<vector::ShuffleOp>(extractStridedSliceOp,
+                                                   flatOutputType, srcVector,
+                                                   srcVector, indices.value());
     return success();
   }
 };
@@ -298,31 +207,24 @@ struct LinearizeVectorInsertStridedSlice final
 
     // Expect a legalization failure if the strides are not all 1 (if ever the
     // verifier for insert_strided_slice allows non-1 strides).
-    if (!stridesAllOne(insertStridedSliceOp)) {
+    if (insertStridedSliceOp.hasNonUnitStrides()) {
       return rewriter.notifyMatchFailure(
           insertStridedSliceOp,
           "insert_strided_slice with strides != 1 not supported");
     }
 
-    VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
-    ArrayRef<int64_t> inputShape = inputType.getShape();
-
     VectorType outputType = insertStridedSliceOp.getType();
-    ArrayRef<int64_t> outputShape = outputType.getShape();
     int64_t nOutputElements = outputType.getNumElements();
 
-    FailureOr<SmallVector<int64_t>> offsets =
-        intsFromArrayAttr(insertStridedSliceOp.getOffsets());
-    if (failed(offsets)) {
+    FailureOr<SmallVector<int64_t>> sliceIndices =
+        insertStridedSliceOp.getLinearIndices();
+    if (failed(sliceIndices))
       return rewriter.notifyMatchFailure(insertStridedSliceOp,
-                                         "failed to get integer offsets");
-    }
-    SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
-        inputShape, outputShape, offsets.value());
+                                         "failed to get indices");
 
     SmallVector<int64_t> indices(nOutputElements);
     std::iota(indices.begin(), indices.end(), 0);
-    for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
+    for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices.value())) {
       indices[sliceIndex] = index + nOutputElements;
     }
 

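The last hunk builds the vector.shuffle mask for the insert case: it starts from the identity over the destination's elements, then redirects each position covered by the slice to an element of the second shuffle operand. Here is a minimal standalone sketch of that step, assuming the shuffle takes the flattened destination as its first operand and the flattened slice as its second (the hunk does not show the shuffle itself). The name insertShuffleMaskSketch is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper mirroring the mask construction in
// LinearizeVectorInsertStridedSlice. Mask values below nOutputElements select
// from the first operand; values at or above it select from the second.
std::vector<int64_t>
insertShuffleMaskSketch(int64_t nOutputElements,
                        const std::vector<int64_t> &sliceIndices) {
  // Identity mask: keep every destination element by default.
  std::vector<int64_t> mask(nOutputElements);
  std::iota(mask.begin(), mask.end(), 0);
  // Overwrite the covered positions with the slice's elements, in order.
  for (std::size_t index = 0; index < sliceIndices.size(); ++index)
    mask[sliceIndices[index]] = static_cast<int64_t>(index) + nOutputElements;
  return mask;
}
```

For a 6-element destination with slice indices [4, 5], the mask is [0, 1, 2, 3, 6, 7]: the first four elements come from the destination and the last two from the slice.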
@newling newling closed this May 29, 2025