Skip to content

[mlir][MemRef] Use specialized index ops to fold expand/collapse_shape #138930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

krzysz00
Copy link
Contributor

@krzysz00 krzysz00 commented May 7, 2025

This PR updates the FoldMemRefAliasOps to use affine.linearize_index and affine.delinearize_index to perform the index computations needed to fold a memref.expand_shape or memref.collapse_shape into its consumers, respectively.

This also loosens some limitations of the pass:

  1. The existing output_shape argument to memref.expand_shape is now used, eliminating the need to re-infer this shape or call memref.dim.
  2. Because we're using affine.delinearize_index, the restriction that each group in a memref.collapse_shape can only have one dynamic dimension is removed.

@krzysz00
Copy link
Contributor Author

krzysz00 commented May 7, 2025

@llvmbot
Copy link
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-mlir-memref

Author: Krzysztof Drewniak (krzysz00)

Changes

This PR updates the FoldMemRefAliasOps to use affine.linearize_index and affine.delinearize_index to perform the index computations needed to fold a memref.expand_shape or memref.collapse_shape into its consumers, respectively.

This also loosens some limitations of the pass:

  1. The existing output_shape argument to memref.expand_shape is now used, eliminating the need to re-infer this shape or call memref.dim.
  2. Because we're using affine.delinearize_index, the restriction that each group in a memref.collapse_shape can only have one dynamic dimension is removed.

Patch is 31.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138930.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+10-4)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+49-120)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+64-65)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6d8161d3117b..f34b5b46cab50 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1342,14 +1342,14 @@ def MemRef_ReinterpretCastOp
     according to specified offsets, sizes, and strides.
 
     ```mlir
-    %result1 = memref.reinterpret_cast %arg0 to 
+    %result1 = memref.reinterpret_cast %arg0 to
       offset: [9],
       sizes: [4, 4],
       strides: [16, 2]
     : memref<8x8xf32, strided<[8, 1], offset: 0>> to
       memref<4x4xf32, strided<[16, 2], offset: 9>>
 
-    %result2 = memref.reinterpret_cast %result1 to 
+    %result2 = memref.reinterpret_cast %result1 to
       offset: [0],
       sizes: [2, 2],
       strides: [4, 2]
@@ -1755,6 +1755,12 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
         OpBuilder &b, Location loc, MemRefType expandedType,
         ArrayRef<ReassociationIndices> reassociation,
         ArrayRef<OpFoldResult> inputShape);
+
+    // Return a vector with all the static and dynamic values in the output shape.
+    SmallVector<OpFoldResult> getMixedOutputShape() {
+      OpBuilder builder(getContext());
+      return ::mlir::getMixedValues(getStaticOutputShape(), getOutputShape(), builder);
+    }
   }];
 
   let hasVerifier = 1;
@@ -1873,7 +1879,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
   let summary = "store operation";
   let description = [{
     The `store` op stores an element into a memref at the specified indices.
-    
+
     The number of indices must match the rank of the memref. The indices must
     be in-bounds: `0 <= idx < dim_size`
 
@@ -2025,7 +2031,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     Unlike the `reinterpret_cast`, the values are relative to the strided
     memref of the input (`%result1` in this case) and not its
     underlying memory.
-    
+
     Example 2:
 
     ```mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index e4fb3f9bb87ed..2acb90613e5d1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -59,92 +59,28 @@ using namespace mlir;
 ///
 /// %2 = load %0[6 * i1 + i2, %i3] :
 ///          memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
-                                memref::ExpandShapeOp expandShapeOp,
-                                ValueRange indices,
-                                SmallVectorImpl<Value> &sourceIndices) {
-  // Record the rewriter context for constructing ops later.
-  MLIRContext *ctx = rewriter.getContext();
-
-  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
-  // This is done for the purpose of inferring the output shape via
-  // `inferExpandOutputShape` which will in turn be used for suffix product
-  // calculation later.
-  SmallVector<OpFoldResult> srcShape;
-  MemRefType srcType = expandShapeOp.getSrcType();
-
-  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
-    if (srcType.isDynamicDim(i)) {
-      srcShape.push_back(
-          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
-              .getResult());
-    } else {
-      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
-    }
-  }
-
-  auto outputShape = inferExpandShapeOutputShape(
-      rewriter, loc, expandShapeOp.getResultType(),
-      expandShapeOp.getReassociationIndices(), srcShape);
-  if (!outputShape.has_value())
-    return failure();
+static LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
 
   // Traverse all reassociation groups to determine the appropriate indices
   // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    // Flag to indicate the presence of dynamic dimensions in current
-    // reassociation group.
-    int64_t groupSize = groups.size();
-
-    // Group output dimensions utilized in this reassociation group for suffix
-    // product calculation.
-    SmallVector<OpFoldResult> sizesVal(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i) {
-      sizesVal[i] = (*outputShape)[groups[i]];
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+    if (groupSize == 1) {
+      sourceIndices.push_back(indices[group[0]]);
+      continue;
     }
-
-    // Calculate suffix product of relevant output dimension sizes.
-    SmallVector<OpFoldResult> suffixProduct =
-        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
-
-    // Create affine expression variables for dimensions and symbols in the
-    // newly constructed affine map.
-    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
-    bindDimsList<AffineExpr>(ctx, dims);
-    bindSymbolsList<AffineExpr>(ctx, symbols);
-
-    // Linearize binded dimensions and symbols to construct the resultant
-    // affine expression for this indice.
-    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
-
-    // Record the load index corresponding to each dimension in the
-    // reassociation group. These are later supplied as operands to the affine
-    // map used for calulating relevant index post op folding.
-    SmallVector<OpFoldResult> dynamicIndices(groupSize);
-    for (int64_t i = 0; i < groupSize; i++)
-      dynamicIndices[i] = indices[groups[i]];
-
-    // Supply suffix product results followed by load op indices as operands
-    // to the map.
-    SmallVector<OpFoldResult> mapOperands;
-    llvm::append_range(mapOperands, suffixProduct);
-    llvm::append_range(mapOperands, dynamicIndices);
-
-    // Creating maximally folded and composed affine.apply composes better
-    // with other transformations without interleaving canonicalization
-    // passes.
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc,
-        AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
-        mapOperands);
-
-    // Push index value in the op post folding corresponding to this
-    // reassociation group.
-    sourceIndices.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    SmallVector<OpFoldResult> groupBasis =
+        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+    SmallVector<Value> groupIndices =
+        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+    sourceIndices.push_back(collapsedIndex);
   }
   return success();
 }
@@ -167,49 +103,34 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                   memref::CollapseShapeOp collapseShapeOp,
                                   ValueRange indices,
                                   SmallVectorImpl<Value> &sourceIndices) {
-  int64_t cnt = 0;
-  SmallVector<OpFoldResult> dynamicIndices;
-  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    dynamicIndices.push_back(indices[cnt++]);
-    int64_t groupSize = groups.size();
-
-    // Calculate suffix product for all collapse op source dimension sizes
-    // except the most major one of each group.
-    // We allow the most major source dimension to be dynamic but enforce all
-    // others to be known statically.
-    SmallVector<int64_t> sizes(groupSize, 1);
-    for (int64_t i = 1; i < groupSize; ++i) {
-      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
-      if (sizes[i] == ShapedType::kDynamic)
-        return failure();
-    }
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-
-    // Derive the index values along all dimensions of the source corresponding
-    // to the index wrt to collapsed shape op output.
-    auto d0 = rewriter.getAffineDimExpr(0);
-    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
-
-    // Construct the AffineApplyOp for each delinearizingExpr.
-    for (int64_t i = 0; i < groupSize; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc,
-          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
-                         delinearizingExprs[i]),
-          dynamicIndices);
-      sourceIndices.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+  MemRefType sourceType = collapseShapeOp.getSrcType();
+  // Note: collapse_shape requires a strided memref, we can do this.
+  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+      loc, collapseShapeOp.getSrc());
+  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+  for (auto [index, group] :
+       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+
+    if (groupSize == 1) {
+      sourceIndices.push_back(index);
+      continue;
     }
-    dynamicIndices.clear();
+
+    SmallVector<OpFoldResult> basis =
+        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        loc, index, basis, /*hasOuterBound=*/true);
+    llvm::append_range(sourceIndices, delinearize.getResults());
   }
   if (collapseShapeOp.getReassociationIndices().empty()) {
     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
     int64_t srcRank =
         cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
     for (int64_t i = 0; i < srcRank; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, zeroAffineMap, dynamicIndices);
       sourceIndices.push_back(
           getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
     }
@@ -513,8 +434,12 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.load and affine.load guarantee that indexes start inbounds
+  // while the vector operations don't. This impacts if our linearization
+  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
       .Case([&](affine::AffineLoadOp op) {
@@ -676,8 +601,12 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.store and affine.store guarantee that indexes start inbounds
+  // while the vector operations don't. This impacts if our linearization
+  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
       .Case([&](affine::AffineStoreOp op) {
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index a27fbf26e13d8..106652623933f 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -408,7 +408,6 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3
 
 // -----
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 6 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
@@ -416,14 +415,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0
   %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
   return %1 : f32
 }
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (2, 6)
 // CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
 // CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -431,15 +428,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg
   %1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
   return %1 : f32
 }
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<2x6x32xf32>
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<2x6x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
 // CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
 // CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
 func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -447,14 +441,28 @@ func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x
   %1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
   return %1 : f32
 }
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<?x6x32xf32>
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x6x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 6 + s1 * 3 + s2)>
+// CHECK-LABEL: @fold_fully_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_fully_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32 {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x?x?xf32> into memref<?x?xf32>
+  %1 = affine.load %0[%arg1, %arg2] : memref<?x?xf32>
+  return %1 : f32
+}
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, %[[SIZES]]#1)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x?x?xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+
+// -----
+
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
 // CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
@@ -462,7 +470,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
   %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
   return %1 : f32
 }
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]], %[[ARG3]]] by (2, 2, 3)
 // CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
@@ -476,7 +484,10 @@ func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x
   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return %[[VAL1]] : f32
 
 // -----
@@ -490,14 +501,16 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
   memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return
 }
-// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return
 
 // -----
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
 // CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
 // CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
 func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -513,21 +526,20 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
   }
   return
 }
+// CHECK-NEXT:   %[[C0:.*]] = arith.constant 0...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 7, 2025

@llvm/pr-subscribers-mlir

Author: Krzysztof Drewniak (krzysz00)

Changes

This PR updates the FoldMemRefAliasOps to use affine.linearize_index and affine.delinearize_index to perform the index computations needed to fold a memref.expand_shape or memref.collapse_shape into its consumers, respectively.

This also loosens some limitations of the pass:

  1. The existing output_shape argument to memref.expand_shape is now used, eliminating the need to re-infer this shape or call memref.dim.
  2. Because we're using affine.delinearize_index, the restriction that each group in a memref.collapse_shape can only have one dynamic dimension is removed.

Patch is 31.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138930.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td (+10-4)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp (+49-120)
  • (modified) mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir (+64-65)
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
index d6d8161d3117b..f34b5b46cab50 100644
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -1342,14 +1342,14 @@ def MemRef_ReinterpretCastOp
     according to specified offsets, sizes, and strides.
 
     ```mlir
-    %result1 = memref.reinterpret_cast %arg0 to 
+    %result1 = memref.reinterpret_cast %arg0 to
       offset: [9],
       sizes: [4, 4],
       strides: [16, 2]
     : memref<8x8xf32, strided<[8, 1], offset: 0>> to
       memref<4x4xf32, strided<[16, 2], offset: 9>>
 
-    %result2 = memref.reinterpret_cast %result1 to 
+    %result2 = memref.reinterpret_cast %result1 to
       offset: [0],
       sizes: [2, 2],
       strides: [4, 2]
@@ -1755,6 +1755,12 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape", [
         OpBuilder &b, Location loc, MemRefType expandedType,
         ArrayRef<ReassociationIndices> reassociation,
         ArrayRef<OpFoldResult> inputShape);
+
+    // Return a vector with all the static and dynamic values in the output shape.
+    SmallVector<OpFoldResult> getMixedOutputShape() {
+      OpBuilder builder(getContext());
+      return ::mlir::getMixedValues(getStaticOutputShape(), getOutputShape(), builder);
+    }
   }];
 
   let hasVerifier = 1;
@@ -1873,7 +1879,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
   let summary = "store operation";
   let description = [{
     The `store` op stores an element into a memref at the specified indices.
-    
+
     The number of indices must match the rank of the memref. The indices must
     be in-bounds: `0 <= idx < dim_size`
 
@@ -2025,7 +2031,7 @@ def SubViewOp : MemRef_OpWithOffsetSizesAndStrides<"subview", [
     Unlike the `reinterpret_cast`, the values are relative to the strided
     memref of the input (`%result1` in this case) and not its
     underlying memory.
-    
+
     Example 2:
 
     ```mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index e4fb3f9bb87ed..2acb90613e5d1 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -59,92 +59,28 @@ using namespace mlir;
 ///
 /// %2 = load %0[6 * i1 + i2, %i3] :
 ///          memref<12x42xf32>
-static LogicalResult
-resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
-                                memref::ExpandShapeOp expandShapeOp,
-                                ValueRange indices,
-                                SmallVectorImpl<Value> &sourceIndices) {
-  // Record the rewriter context for constructing ops later.
-  MLIRContext *ctx = rewriter.getContext();
-
-  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
-  // This is done for the purpose of inferring the output shape via
-  // `inferExpandOutputShape` which will in turn be used for suffix product
-  // calculation later.
-  SmallVector<OpFoldResult> srcShape;
-  MemRefType srcType = expandShapeOp.getSrcType();
-
-  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
-    if (srcType.isDynamicDim(i)) {
-      srcShape.push_back(
-          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
-              .getResult());
-    } else {
-      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
-    }
-  }
-
-  auto outputShape = inferExpandShapeOutputShape(
-      rewriter, loc, expandShapeOp.getResultType(),
-      expandShapeOp.getReassociationIndices(), srcShape);
-  if (!outputShape.has_value())
-    return failure();
+static LogicalResult resolveSourceIndicesExpandShape(
+    Location loc, PatternRewriter &rewriter,
+    memref::ExpandShapeOp expandShapeOp, ValueRange indices,
+    SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) {
+  SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape();
 
   // Traverse all reassociation groups to determine the appropriate indices
   // corresponding to each one of them post op folding.
-  for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    // Flag to indicate the presence of dynamic dimensions in current
-    // reassociation group.
-    int64_t groupSize = groups.size();
-
-    // Group output dimensions utilized in this reassociation group for suffix
-    // product calculation.
-    SmallVector<OpFoldResult> sizesVal(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i) {
-      sizesVal[i] = (*outputShape)[groups[i]];
+  for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+    if (groupSize == 1) {
+      sourceIndices.push_back(indices[group[0]]);
+      continue;
     }
-
-    // Calculate suffix product of relevant output dimension sizes.
-    SmallVector<OpFoldResult> suffixProduct =
-        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
-
-    // Create affine expression variables for dimensions and symbols in the
-    // newly constructed affine map.
-    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
-    bindDimsList<AffineExpr>(ctx, dims);
-    bindSymbolsList<AffineExpr>(ctx, symbols);
-
-    // Linearize binded dimensions and symbols to construct the resultant
-    // affine expression for this indice.
-    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
-
-    // Record the load index corresponding to each dimension in the
-    // reassociation group. These are later supplied as operands to the affine
-    // map used for calulating relevant index post op folding.
-    SmallVector<OpFoldResult> dynamicIndices(groupSize);
-    for (int64_t i = 0; i < groupSize; i++)
-      dynamicIndices[i] = indices[groups[i]];
-
-    // Supply suffix product results followed by load op indices as operands
-    // to the map.
-    SmallVector<OpFoldResult> mapOperands;
-    llvm::append_range(mapOperands, suffixProduct);
-    llvm::append_range(mapOperands, dynamicIndices);
-
-    // Creating maximally folded and composed affine.apply composes better
-    // with other transformations without interleaving canonicalization
-    // passes.
-    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc,
-        AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
-        mapOperands);
-
-    // Push index value in the op post folding corresponding to this
-    // reassociation group.
-    sourceIndices.push_back(
-        getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+    SmallVector<OpFoldResult> groupBasis =
+        llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; });
+    SmallVector<Value> groupIndices =
+        llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; });
+    Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>(
+        loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds);
+    sourceIndices.push_back(collapsedIndex);
   }
   return success();
 }
@@ -167,49 +103,34 @@ resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter,
                                   memref::CollapseShapeOp collapseShapeOp,
                                   ValueRange indices,
                                   SmallVectorImpl<Value> &sourceIndices) {
-  int64_t cnt = 0;
-  SmallVector<OpFoldResult> dynamicIndices;
-  for (ArrayRef<int64_t> groups : collapseShapeOp.getReassociationIndices()) {
-    assert(!groups.empty() && "association indices groups cannot be empty");
-    dynamicIndices.push_back(indices[cnt++]);
-    int64_t groupSize = groups.size();
-
-    // Calculate suffix product for all collapse op source dimension sizes
-    // except the most major one of each group.
-    // We allow the most major source dimension to be dynamic but enforce all
-    // others to be known statically.
-    SmallVector<int64_t> sizes(groupSize, 1);
-    for (int64_t i = 1; i < groupSize; ++i) {
-      sizes[i] = collapseShapeOp.getSrcType().getDimSize(groups[i]);
-      if (sizes[i] == ShapedType::kDynamic)
-        return failure();
-    }
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-
-    // Derive the index values along all dimensions of the source corresponding
-    // to the index wrt to collapsed shape op output.
-    auto d0 = rewriter.getAffineDimExpr(0);
-    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, suffixProduct);
-
-    // Construct the AffineApplyOp for each delinearizingExpr.
-    for (int64_t i = 0; i < groupSize; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc,
-          AffineMap::get(/*numDims=*/1, /*numSymbols=*/0,
-                         delinearizingExprs[i]),
-          dynamicIndices);
-      sourceIndices.push_back(
-          getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
+  MemRefType sourceType = collapseShapeOp.getSrcType();
+  // Note: collapse_shape requires a strided memref, we can do this.
+  auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>(
+      loc, collapseShapeOp.getSrc());
+  SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes();
+  for (auto [index, group] :
+       llvm::zip(indices, collapseShapeOp.getReassociationIndices())) {
+    assert(!group.empty() && "association indices groups cannot be empty");
+    int64_t groupSize = group.size();
+
+    if (groupSize == 1) {
+      sourceIndices.push_back(index);
+      continue;
     }
-    dynamicIndices.clear();
+
+    SmallVector<OpFoldResult> basis =
+        llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
+    auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
+        loc, index, basis, /*hasOuterBound=*/true);
+    llvm::append_range(sourceIndices, delinearize.getResults());
   }
   if (collapseShapeOp.getReassociationIndices().empty()) {
     auto zeroAffineMap = rewriter.getConstantAffineMap(0);
     int64_t srcRank =
         cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank();
+    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{});
     for (int64_t i = 0; i < srcRank; i++) {
-      OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, zeroAffineMap, dynamicIndices);
       sourceIndices.push_back(
           getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
     }
@@ -513,8 +434,12 @@ LogicalResult LoadOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.load and affine.load guarantee that indexes start inbounds
+  // while the vector operations don't. This impacts if our linearization
+  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          loadOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineLoadOp, memref::LoadOp>(loadOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(loadOp)
       .Case([&](affine::AffineLoadOp op) {
@@ -676,8 +601,12 @@ LogicalResult StoreOpOfExpandShapeOpFolder<OpTy>::matchAndRewrite(
     indices.assign(expandedIndices.begin(), expandedIndices.end());
   }
   SmallVector<Value> sourceIndices;
+  // memref.store and affine.store guarantee that indexes start inbounds
+  // while the vector operations don't. This impacts if our linearization
+  // is `disjoint`
   if (failed(resolveSourceIndicesExpandShape(
-          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices)))
+          storeOp.getLoc(), rewriter, expandShapeOp, indices, sourceIndices,
+          isa<affine::AffineStoreOp, memref::StoreOp>(storeOp.getOperation()))))
     return failure();
   llvm::TypeSwitch<Operation *, void>(storeOp)
       .Case([&](affine::AffineStoreOp op) {
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index a27fbf26e13d8..106652623933f 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -408,7 +408,6 @@ func.func @fold_static_stride_subview_with_affine_load_store(%arg0 : memref<12x3
 
 // -----
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s0 * 6 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32 {
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index) -> f32 {
@@ -416,14 +415,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0
   %1 = affine.load %0[%arg1, %arg2, %arg3] : memref<2x6x32xf32>
   return %1 : f32
 }
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]]] by (2, 6)
 // CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG3]]] : memref<12x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
 // CHECK-LABEL: @fold_static_stride_subview_with_affine_load_store_collapse_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg0 : memref<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -431,15 +428,12 @@ func.func @fold_static_stride_subview_with_affine_load_store_collapse_shape(%arg
   %1 = affine.load %0[%arg1, %arg2] : memref<12x32xf32>
   return %1 : f32
 }
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<2x6x32xf32>
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<2x6x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 floordiv 6)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 mod 6)>
 // CHECK-LABEL: @fold_dynamic_size_collapse_shape_with_affine_load
 // CHECK-SAME: (%[[ARG0:.*]]: memref<?x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
 func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
@@ -447,14 +441,28 @@ func.func @fold_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x6x
   %1 = affine.load %0[%arg1, %arg2] : memref<?x32xf32>
   return %1 : f32
 }
-// CHECK-NEXT: %[[MODIFIED_INDEX0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[MODIFIED_INDEX1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG1]]]
-// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEX0]], %[[MODIFIED_INDEX1]], %[[ARG2]]] : memref<?x6x32xf32>
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, 6)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x6x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 6 + s1 * 3 + s2)>
+// CHECK-LABEL: @fold_fully_dynamic_size_collapse_shape_with_affine_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @fold_fully_dynamic_size_collapse_shape_with_affine_load(%arg0 : memref<?x?x?xf32>, %arg1 : index, %arg2 : index) -> f32 {
+  %0 = memref.collapse_shape %arg0 [[0, 1], [2]] : memref<?x?x?xf32> into memref<?x?xf32>
+  %1 = affine.load %0[%arg1, %arg2] : memref<?x?xf32>
+  return %1 : f32
+}
+// CHECK-NEXT: %{{.*}}, %{{.*}}, %[[SIZES:.*]]:3, %{{.*}}:3 = memref.extract_strided_metadata %[[ARG0]]
+// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (%[[SIZES]]#0, %[[SIZES]]#1)
+// CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : memref<?x?x?xf32>
+// CHECK-NEXT: return %[[RESULT]] : f32
+
+
+// -----
+
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_3d
 // CHECK-SAME: (%[[ARG0:.*]]: memref<12x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index) -> f32 {
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4: index) -> f32 {
@@ -462,7 +470,7 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
   %1 = affine.load %0[%arg1, %arg2, %arg3, %arg4] : memref<2x2x3x32xf32>
   return %1 : f32
 }
-// CHECK: %[[INDEX:.*]] = affine.apply #[[$MAP]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]]]
+// CHECK: %[[INDEX:.*]] = affine.linearize_index disjoint [%[[ARG1]], %[[ARG2]], %[[ARG3]]] by (2, 2, 3)
 // CHECK-NEXT: %[[RESULT:.*]] = affine.load %[[ARG0]][%[[INDEX]], %[[ARG4]]] : memref<12x32xf32>
 // CHECK-NEXT: return %[[RESULT]] : f32
 
@@ -476,7 +484,10 @@ func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x
   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: %[[C0:.*]] = arith.constant 0
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return %[[VAL1]] : f32
 
 // -----
@@ -490,14 +501,16 @@ func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16
   memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] {nontemporal = true} : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return
 }
-// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: %[[INDEX1:.*]] = affine.linearize_index disjoint [%[[C0]], %[[ARG1]]] by (1, 16)
+// CHECK-NEXT: %[[INDEX2:.*]] = affine.linearize_index disjoint [%[[ARG2]], %[[C0]]] by (%[[ARG3]], 1)
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[INDEX1]], %[[INDEX2]]] {nontemporal = true} : memref<16x?xf32, strided<[16, 1]>>
 // CHECK-NEXT: return
 
 // -----
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0)[s0] -> (d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
 // CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
 // CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
 func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
@@ -513,21 +526,20 @@ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc:
   }
   return
 }
+// CHECK-NEXT:   %[[C0:.*]] = arith.constant 0...
[truncated]

@krzysz00
Copy link
Contributor Author

krzysz00 commented May 9, 2025

Ping

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes look good to me. It isn't strictly required, by given that book h of us work on the same downstream project, does this pass with the said downstream project. But this looks good to me

Base automatically changed from users/krzysz00/linearize-delinearize-dims to main May 13, 2025 16:13
This PR updates the FoldMemRefAliasOps to use `affine.linearize_index`
and `affine.delinearize_index` to perform the index computations
needed to fold a `memref.expand_shape` or `memref.collapse_shape` into
its consumers, respectively.

This also loosens some limitations of the pass:
1. The existing `output_shape` argument to `memref.expand_shape` is
now used, eliminating the need to re-infer this shape or call
`memref.dim`.
2. Because we're using `affine.delinearize_index`, the restriction
that each group in a `memref.collapse_shape` can only have one dynamic
dimension is removed.
@krzysz00 krzysz00 force-pushed the users/krzysz00/linearize-delinearize-alias-ops-folder branch from ae27c6f to 249e426 Compare May 13, 2025 16:16
@krzysz00 krzysz00 merged commit a891163 into main May 13, 2025
9 of 10 checks passed
@krzysz00 krzysz00 deleted the users/krzysz00/linearize-delinearize-alias-ops-folder branch May 13, 2025 18:28
@kazutakahirata
Copy link
Contributor

@krzysz00 I've landed 2534839 to fix a warning from this PR. Thanks!

@krzysz00
Copy link
Contributor Author

Thanks!

krzysz00 added a commit to iree-org/iree that referenced this pull request May 16, 2025
Integrate torch-mlir@389541fb9ddd

Integrate stablehlo@5837b2a6ce

# Dropped reverts
- Drop revert of the 1:N dialect conversion removal since all
dependencies have migrated
- Drop revert of the APIntParameter error since all dependencies have
migrated
- Drop revert of allowing function type conversion to fail since all
dependencies have migrated
- Drop local modifications to torch-mlir and stablehlo, those are now
clean submodules (they're no longer needed now that we've dropped
reverts and they've migrated)

# Continued reverts
- We still have a revert of upstream #137930 since it's not clear the
roccertness issue is resolved (that's the SDWA cndmask thing, see
llvm/llvm-project#138766 )
- We still have a revert of upstream #133231 since that could still be
breaking tests

# Changes
- Rename bufferization.to_memref to bufferization.to_buffer everewhere
- Swap all getSource() on Transfer*Op to getBase()
- Rename the memref argument of TransferGatherOp to `base` to match the
transfer interface
- Remove argument materialization calls
- In the one case where this wasn't trivial, migrate to a target
materilazition since that's what upstream advice suggested
- Handle (in any way that seemed appropriate) the failures that
eraseArguments() and eraseResults() can now have
- Slightly reshuffle the LLVMCPU pipeline so that there's an
`affine-expand-index-ops` after the last `FoldMemRefAliasOps` call,
because `FoldMemRefAliasOps` now creates `affine.linearize_index` and
`affine.delinearize_index` which don't seem to lower to LLVM right on
their own. See llvm/llvm-project#138930
- Update narrow type emulation tests to account for correctness fixes in
the lineraized shape determination widget
- Update vectorization tests to account for changes in ee47454bb8be and
update the attention tiling test to account for an unknown change
- Update StableHLO rewrites to account for accuracy arguments
qedawkins added a commit to iree-org/llvm-project that referenced this pull request Jun 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants