
[mlir][Linalg] Allow expand shape propagation across linalg ops with dynamic shapes. #127943


Merged

Conversation

MaheshRavishankar
Contributor

With tensor.expand_shape now allowing a dynamic dimension to be expanded into multiple dynamic dimensions, adapt the reshape propagation through expansion to handle cases where one dynamic dimension is expanded into multiple dynamic dimensions.
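
For illustration, a minimal sketch of the tensor.expand_shape form this relies on (hypothetical IR, not taken from this PR's tests): a single dynamic source dimension is split into two dynamic result dimensions, with the split sizes supplied explicitly through output_shape.

// Hypothetical example: dynamic dimension 0 of the source is expanded into two
// dynamic dimensions whose sizes %s0 and %s1 are provided by the caller.
func.func @split_dynamic(%arg0: tensor<?x32xf32>, %s0: index, %s1: index)
    -> tensor<?x?x32xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [%s0, %s1, 32]
      : tensor<?x32xf32> into tensor<?x?x32xf32>
  return %0 : tensor<?x?x32xf32>
}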

@llvmbot
Member

llvmbot commented Feb 20, 2025

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes

With tensor.expand_shape now allowing a dynamic dimension to be expanded into multiple dynamic dimensions, adapt the reshape propagation through expansion to handle cases where one dynamic dimension is expanded into multiple dynamic dimensions.
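
As a rough before/after sketch of what the pattern now handles (hypothetical IR, not from reshape_fusion.mlir; the function and value names are made up), consider an expanding reshape whose reassociation group contains two dynamic sizes:

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @consumer_expand(%arg0: tensor<?x?xf32>, %init: tensor<?x?xf32>,
                           %s0: index, %s1: index) -> tensor<?x?x?xf32> {
  %c0 = arith.constant 0 : index
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
  %0 = linalg.generic
         {indexing_maps = [#map, #map],
          iterator_types = ["parallel", "parallel"]}
         ins(%arg0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
       ^bb0(%in: f32, %out: f32):
         %1 = arith.addf %in, %in : f32
         linalg.yield %1 : f32
       } -> tensor<?x?xf32>
  // Both expanded sizes of the trailing dimension are dynamic; previously the
  // propagation bailed out on such a reassociation group.
  %expanded = tensor.expand_shape %0 [[0], [1, 2]] output_shape [%d0, %s0, %s1]
      : tensor<?x?xf32> into tensor<?x?x?xf32>
  return %expanded : tensor<?x?x?xf32>
}

After the rewrite, equivalent tensor.expand_shape ops (reusing [%d0, %s0, %s1] as the output shape) are placed on %arg0 and %init, and the linalg.generic is rebuilt on the expanded tensor<?x?x?xf32> operands, so the reshape keeps propagating toward the function arguments.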


Patch is 46.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127943.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+70-107)
  • (modified) mlir/test/Dialect/Linalg/reshape_fusion.mlir (+61-188)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 60cae77644291..b4da6d3d37354 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -595,18 +595,17 @@ class ExpansionInfo {
   // the expanded op.
   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                         ArrayRef<AffineMap> reassociationMaps,
-                        ArrayRef<int64_t> expandedShape,
-                        ArrayRef<int64_t> collapsedShape,
+                        ArrayRef<OpFoldResult> expandedShape,
                         PatternRewriter &rewriter);
   unsigned getOrigOpNumDims() const { return reassociation.size(); }
   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
   ReassociationIndicesRef getExpandedDims(unsigned i) const {
     return reassociation[i];
   }
-  ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+  ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
     return expandedShapeMap[i];
   }
-  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+  ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }
 
 private:
   /// Reassociation from the dimensions in the original operation to the
@@ -614,9 +613,9 @@ class ExpansionInfo {
   SmallVector<ReassociationIndices> reassociation;
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
-  SmallVector<SmallVector<int64_t>> expandedShapeMap;
+  SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
   /// Extent of the loop in the original operation.
-  SmallVector<int64_t> originalLoopExtent;
+  SmallVector<OpFoldResult> originalLoopExtent;
   unsigned expandedOpNumDims;
 };
 } // namespace
@@ -624,15 +623,17 @@ class ExpansionInfo {
 LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
                                      OpOperand *fusableOpOperand,
                                      ArrayRef<AffineMap> reassociationMaps,
-                                     ArrayRef<int64_t> expandedShape,
-                                     ArrayRef<int64_t> collapsedShape,
+                                     ArrayRef<OpFoldResult> expandedShape,
                                      PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
   AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
 
-  SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
-  originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(linalgOp);
+  originalLoopExtent = llvm::map_to_vector(
+      linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
+      [](Range r) { return r.size; });
 
   reassociation.clear();
   expandedShapeMap.clear();
@@ -644,7 +645,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
     unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
     numExpandedDims[pos] = foldedDims.getNumResults();
-    ArrayRef<int64_t> shape =
+    ArrayRef<OpFoldResult> shape =
         expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
     expandedShapeMap[pos].assign(shape.begin(), shape.end());
   }
@@ -665,33 +666,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
   return success();
 }
 
-/// Expanding the body of a linalg operation requires adaptations of the
-/// accessed loop indices. Specifically, access of indices in the original
-/// operation need to be replaced with linearizations of indices in the expanded
-/// op. That requires the shape of the expanded dimensions to be static (at
-/// least all but the most significant). For now check that these are all
-/// statically sized. Note that this could be extended to handle dynamic case,
-/// but the implementation below uses `affine.apply` which seems to have issues
-/// when the shapes are not static.
-static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
-                                          const ExpansionInfo &expansionInfo,
-                                          PatternRewriter &rewriter) {
-  if (!linalgOp.hasIndexSemantics())
-    return success();
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    for (int64_t shape : expandedShape.drop_front()) {
-      if (ShapedType::isDynamic(shape)) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot expand due to index semantics and dynamic dims");
-      }
-    }
-  }
-  return success();
-}
-
 /// Return the indexing map to use in the expanded op for a given the
 /// `indexingMap` of the original operation.
 static AffineMap
@@ -713,16 +687,28 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
 
 /// Return the type of the operand/result to use in the expanded op given the
 /// type in the original op.
-static RankedTensorType getExpandedType(RankedTensorType originalType,
-                                        AffineMap indexingMap,
-                                        const ExpansionInfo &expansionInfo) {
-  SmallVector<int64_t> expandedShape;
+static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
+getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
+                        const ExpansionInfo &expansionInfo) {
+  SmallVector<int64_t> expandedStaticShape;
+  SmallVector<OpFoldResult> expandedShape;
   for (AffineExpr expr : indexingMap.getResults()) {
     unsigned dim = cast<AffineDimExpr>(expr).getPosition();
-    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+    ArrayRef<OpFoldResult> dimExpansion =
+        expansionInfo.getExpandedShapeOfDim(dim);
+    llvm::append_range(expandedStaticShape,
+                       llvm::map_range(dimExpansion, [](OpFoldResult ofr) {
+                         std::optional<int64_t> staticShape =
+                             getConstantIntValue(ofr);
+                         if (staticShape) {
+                           return staticShape.value();
+                         }
+                         return ShapedType::kDynamic;
+                       }));
     expandedShape.append(dimExpansion.begin(), dimExpansion.end());
   }
-  return RankedTensorType::get(expandedShape, originalType.getElementType());
+  return {expandedShape, RankedTensorType::get(expandedStaticShape,
+                                               originalType.getElementType())};
 }
 
 /// Returns the reassociation maps to use in the `tensor.expand_shape`
@@ -770,49 +756,27 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
     // Linearize the expanded indices of the original index dimension.
     OpBuilder::InsertionGuard guard(rewriter);
     rewriter.setInsertionPointAfter(indexOp);
-    ArrayRef<int64_t> expandedDimsShape =
+    ArrayRef<OpFoldResult> expandedDimsShape =
         expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
     SmallVector<Value> expandedIndices;
     expandedIndices.reserve(expandedDims.size() - 1);
     llvm::transform(
         expandedDims.drop_front(), std::back_inserter(expandedIndices),
         [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
-    Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
+    OpFoldResult newIndex =
+        rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
     for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
-      assert(!ShapedType::isDynamic(std::get<0>(it)));
-      AffineExpr idx, acc;
+      AffineExpr idx, acc, shape;
       bindDims(rewriter.getContext(), idx, acc);
-      newIndex = rewriter.create<affine::AffineApplyOp>(
-          indexOp.getLoc(), idx + acc * std::get<0>(it),
-          ValueRange{std::get<1>(it), newIndex});
-    }
-    rewriter.replaceOp(indexOp, newIndex);
-  }
-}
-
-/// Checks if a single dynamic dimension expanded into multiple dynamic
-/// dimensions.
-static LogicalResult
-validateDynamicDimExpansion(LinalgOp linalgOp,
-                            const ExpansionInfo &expansionInfo,
-                            PatternRewriter &rewriter) {
-  for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
-    ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
-    if (expandedShape.size() == 1)
-      continue;
-    bool foundDynamic = false;
-    for (int64_t shape : expandedShape) {
-      if (!ShapedType::isDynamic(shape))
-        continue;
-      if (foundDynamic) {
-        return rewriter.notifyMatchFailure(
-            linalgOp, "cannot infer expanded shape with multiple dynamic "
-                      "dims in the same reassociation group");
-      }
-      foundDynamic = true;
+      bindSymbols(rewriter.getContext(), shape);
+      newIndex = affine::makeComposedFoldedAffineApply(
+          rewriter, indexOp.getLoc(), idx + acc * shape,
+          ArrayRef<OpFoldResult>{std::get<1>(it), newIndex, std::get<0>(it)});
     }
+    Value newIndexVal =
+        getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
+    rewriter.replaceOp(indexOp, newIndexVal);
   }
-  return success();
 }
 
 /// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
@@ -826,31 +790,25 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
          "preconditions for fuse operation failed");
 
   Location loc = linalgOp.getLoc();
-  // Check if reshape is expanding or collapsing.
-  auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
-  auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
-  bool isExpanding = (expandingReshapeOp != nullptr);
-  RankedTensorType expandedType = isExpanding
-                                      ? expandingReshapeOp.getResultType()
-                                      : collapsingReshapeOp.getSrcType();
-  RankedTensorType collapsedType = isExpanding
-                                       ? expandingReshapeOp.getSrcType()
-                                       : collapsingReshapeOp.getResultType();
+  SmallVector<OpFoldResult> expandedShape, collapsedShape;
+  SmallVector<AffineMap, 4> reassociationIndices;
+  Value src;
+  if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
+    expandedShape = expandingReshapeOp.getMixedOutputShape();
+    reassociationIndices = expandingReshapeOp.getReassociationMaps();
+    src = expandingReshapeOp.getSrc();
+  } else {
+    auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+    expandedShape = tensor::getMixedSizes(
+        rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
+    reassociationIndices = collapsingReshapeOp.getReassociationMaps();
+    src = collapsingReshapeOp.getSrc();
+  }
 
   ExpansionInfo expansionInfo;
   if (failed(expansionInfo.compute(
-          linalgOp, fusableOpOperand,
-          isExpanding ? expandingReshapeOp.getReassociationMaps()
-                      : collapsingReshapeOp.getReassociationMaps(),
-          expandedType.getShape(), collapsedType.getShape(), rewriter)))
-    return std::nullopt;
-
-  // TODO: With the support of multiple dynamic dims expansion in
-  // tensor.expand_shape op, this case can be handled.
-  if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
-    return std::nullopt;
-
-  if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
+          linalgOp, fusableOpOperand, reassociationIndices,
+          expandedShape, rewriter)))
     return std::nullopt;
 
   SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -866,15 +824,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
   for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
     if (opOperand == fusableOpOperand) {
-      expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
-                                               : collapsingReshapeOp.getSrc());
+      expandedOpOperands.push_back(src);
       continue;
     }
     if (auto opOperandType =
             dyn_cast<RankedTensorType>(opOperand->get().getType())) {
       AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
-      RankedTensorType expandedOperandType =
-          getExpandedType(opOperandType, indexingMap, expansionInfo);
+      SmallVector<OpFoldResult> expandedOperandShape;
+      RankedTensorType expandedOperandType;
+      std::tie(expandedOperandShape, expandedOperandType) =
+          getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
         // Reshape the operand to get the right type.
         SmallVector<ReassociationIndices> reassociation =
@@ -888,7 +847,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
                 /*isExpandingReshape=*/true)))
           return std::nullopt;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            loc, expandedOperandType, opOperand->get(), reassociation));
+            loc, expandedOperandType, opOperand->get(), reassociation,
+            expandedOperandShape));
         continue;
       }
     }
@@ -899,8 +859,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
   for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
     AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
     auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
-    RankedTensorType expandedOutputType =
-        getExpandedType(opOperandType, indexingMap, expansionInfo);
+    SmallVector<OpFoldResult> expandedOutputShape;
+    RankedTensorType expandedOutputType;
+    std::tie(expandedOutputShape, expandedOutputType) =
+        getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
     if (expandedOutputType != opOperand.get().getType()) {
       SmallVector<ReassociationIndices> reassociation =
           getReassociationForExpansion(indexingMap, expansionInfo);
@@ -913,7 +875,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
               /*isExpandingReshape=*/true)))
         return std::nullopt;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          loc, expandedOutputType, opOperand.get(), reassociation));
+          loc, expandedOutputType, opOperand.get(), reassociation,
+          expandedOutputShape));
     } else {
       outputs.push_back(opOperand.get());
     }
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index ef853e4d662a7..57904f912a35b 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -30,20 +30,14 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: f32
-//      CHECK:   %[[C4:.+]] = arith.constant 4 : index
-//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
+//      CHECK:   %[[C3:.+]] = arith.constant 3 : index
 //      CHECK:   %[[C1:.+]] = arith.constant 1 : index
 //      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
-//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
-//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM_1]], %[[C4]] : index
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM]], %[[DIM_0]], %[[VAL_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
-//      CHECK:   %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
-//      CHECK:   %[[DIM_4:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
-//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_4]], %[[C4]] : index
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_2]], %[[DIM_3]], %[[VAL_1]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
+//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x4x?xf32>
+//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x4x?xf32>
+//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<?x?x4x?xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_1]], %[[DIM]], %[[DIM_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_1]], %[[DIM]], %[[DIM_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
@@ -88,21 +82,9 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: f32
 // CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-//      CHECK:   %[[C20:.+]] = arith.constant 20 : index
-//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:   %[[C0:.+]] = arith.constant 0 : index
-//      CHECK:   %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
-//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], 4, %[[VAL_0]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
-//      CHECK:   %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], 4, %[[VAL_1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
-//      CHECK:   %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-//      CHECK:   %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-//      CHECK:   %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], 4, %[[VAL_2]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+//      CHECK:   %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+//      CHECK:   %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+//      CHECK:   %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
 //      CHECK:   %[[T3:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
 // CHECK-SAME:     ["parallel", "parallel", "parallel", "parallel"]
@@ -137,26 +119,9 @@ func.func @reshape_as_consumer_permutation
 // CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
 // CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME:   %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index
-//      CHECK:   %[[C12:.+]] = arith.constant 12 : index
-//      CHECK:   %[[C2:.+]] = arith.constant 2 : index
-//      CHECK:   %[[C1:.+]] = arith.constant 1 : index
-//      CHECK:  ...
[truncated]


github-actions bot commented Feb 20, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Feb 20, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Mar 2, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar force-pushed the dynamicReshapePropagation branch from 09fa404 to 0e134fc on March 12, 2025 01:31
@llvmbot llvmbot added the mlir:core MLIR Core Infrastructure label Mar 12, 2025
@MaheshRavishankar
Contributor Author

Please review the last commit. The first commit is part of #130874.

Comment on lines +884 to +880
if (failed(moveValueDefinitions(
        rewriter, expandingReshapeOp.getOutputShape(), linalgOp)))
  return std::nullopt;
Contributor

@IanWood1 IanWood1 Mar 12, 2025

This might be problematic because after moveValueDefinitions mutates the IR, the rewrite pattern is no longer allowed to return failure().

Contributor

The usage of reshapeLikeShapesAreCompatible has a similar problem below (prior to this PR). So maybe this isn't that big of a deal

Contributor Author

Yeah, that happens in many places in the IR. I don't have a good solution for this though. In general it's just changing the position of some tensor.dim operations, which shouldn't matter too much.

Contributor

@hanhanW hanhanW left a comment

Overall looks good to me, just a few nits!

MaheshRavishankar and others added 2 commits March 12, 2025 21:21
…dynamic shapes.

With `tensor.expand_shape` allowing a dynamic dimension to be expanded
into multiple dynamic dimensions, adapt the reshape propagation through
expansion to handle cases where one dynamic dimension is expanded into
multiple dynamic dimensions.

Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar force-pushed the dynamicReshapePropagation branch from 0e134fc to dbaa97a on March 13, 2025 05:11
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Mar 13, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Mar 13, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Mar 14, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Mar 14, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar merged commit 2490f7f into llvm:main Mar 14, 2025
11 checks passed
qedawkins pushed a commit to iree-org/iree that referenced this pull request Mar 18, 2025
Signed-off-by: MaheshRavishankar <[email protected]>
Labels
mlir:core (MLIR Core Infrastructure), mlir:linalg, mlir
4 participants