-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Linalg] Allow expand shape propagation across linalg ops with dynamic shapes. #127943
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Linalg] Allow expand shape propagation across linalg ops with dynamic shapes. #127943
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: None (MaheshRavishankar) ChangesWith Patch is 46.75 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/127943.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 60cae77644291..b4da6d3d37354 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -595,18 +595,17 @@ class ExpansionInfo {
// the expanded op.
LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
- ArrayRef<int64_t> expandedShape,
- ArrayRef<int64_t> collapsedShape,
+ ArrayRef<OpFoldResult> expandedShape,
PatternRewriter &rewriter);
unsigned getOrigOpNumDims() const { return reassociation.size(); }
unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
ReassociationIndicesRef getExpandedDims(unsigned i) const {
return reassociation[i];
}
- ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
+ ArrayRef<OpFoldResult> getExpandedShapeOfDim(unsigned i) const {
return expandedShapeMap[i];
}
- ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }
+ ArrayRef<OpFoldResult> getOriginalShape() const { return originalLoopExtent; }
private:
/// Reassociation from the dimensions in the original operation to the
@@ -614,9 +613,9 @@ class ExpansionInfo {
SmallVector<ReassociationIndices> reassociation;
/// Mapping from extent of loops in the original operation, to the extent of
/// loops in the expanded operation.
- SmallVector<SmallVector<int64_t>> expandedShapeMap;
+ SmallVector<SmallVector<OpFoldResult>> expandedShapeMap;
/// Extent of the loop in the original operation.
- SmallVector<int64_t> originalLoopExtent;
+ SmallVector<OpFoldResult> originalLoopExtent;
unsigned expandedOpNumDims;
};
} // namespace
@@ -624,15 +623,17 @@ class ExpansionInfo {
LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
OpOperand *fusableOpOperand,
ArrayRef<AffineMap> reassociationMaps,
- ArrayRef<int64_t> expandedShape,
- ArrayRef<int64_t> collapsedShape,
+ ArrayRef<OpFoldResult> expandedShape,
PatternRewriter &rewriter) {
if (reassociationMaps.empty())
return failure();
AffineMap fusedIndexMap = linalgOp.getMatchingIndexingMap(fusableOpOperand);
- SmallVector<int64_t, 4> originalLoopRange = linalgOp.getStaticLoopRanges();
- originalLoopExtent.assign(originalLoopRange.begin(), originalLoopRange.end());
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(linalgOp);
+ originalLoopExtent = llvm::map_to_vector(
+ linalgOp.createLoopRanges(rewriter, linalgOp->getLoc()),
+ [](Range r) { return r.size; });
reassociation.clear();
expandedShapeMap.clear();
@@ -644,7 +645,7 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
unsigned pos = cast<AffineDimExpr>(resultExpr.value()).getPosition();
AffineMap foldedDims = reassociationMaps[resultExpr.index()];
numExpandedDims[pos] = foldedDims.getNumResults();
- ArrayRef<int64_t> shape =
+ ArrayRef<OpFoldResult> shape =
expandedShape.slice(foldedDims.getDimPosition(0), numExpandedDims[pos]);
expandedShapeMap[pos].assign(shape.begin(), shape.end());
}
@@ -665,33 +666,6 @@ LogicalResult ExpansionInfo::compute(LinalgOp linalgOp,
return success();
}
-/// Expanding the body of a linalg operation requires adaptations of the
-/// accessed loop indices. Specifically, access of indices in the original
-/// operation need to be replaced with linearizations of indices in the expanded
-/// op. That requires the shape of the expanded dimensions to be static (at
-/// least all but the most significant). For now check that these are all
-/// statically sized. Note that this could be extended to handle dynamic case,
-/// but the implementation below uses `affine.apply` which seems to have issues
-/// when the shapes are not static.
-static LogicalResult isLinalgOpExpandable(LinalgOp linalgOp,
- const ExpansionInfo &expansionInfo,
- PatternRewriter &rewriter) {
- if (!linalgOp.hasIndexSemantics())
- return success();
- for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
- ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
- if (expandedShape.size() == 1)
- continue;
- for (int64_t shape : expandedShape.drop_front()) {
- if (ShapedType::isDynamic(shape)) {
- return rewriter.notifyMatchFailure(
- linalgOp, "cannot expand due to index semantics and dynamic dims");
- }
- }
- }
- return success();
-}
-
/// Return the indexing map to use in the expanded op for a given the
/// `indexingMap` of the original operation.
static AffineMap
@@ -713,16 +687,28 @@ getIndexingMapInExpandedOp(OpBuilder &builder, AffineMap indexingMap,
/// Return the type of the operand/result to use in the expanded op given the
/// type in the original op.
-static RankedTensorType getExpandedType(RankedTensorType originalType,
- AffineMap indexingMap,
- const ExpansionInfo &expansionInfo) {
- SmallVector<int64_t> expandedShape;
+static std::tuple<SmallVector<OpFoldResult>, RankedTensorType>
+getExpandedShapeAndType(RankedTensorType originalType, AffineMap indexingMap,
+ const ExpansionInfo &expansionInfo) {
+ SmallVector<int64_t> expandedStaticShape;
+ SmallVector<OpFoldResult> expandedShape;
for (AffineExpr expr : indexingMap.getResults()) {
unsigned dim = cast<AffineDimExpr>(expr).getPosition();
- auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+ ArrayRef<OpFoldResult> dimExpansion =
+ expansionInfo.getExpandedShapeOfDim(dim);
+ llvm::append_range(expandedStaticShape,
+ llvm::map_range(dimExpansion, [](OpFoldResult ofr) {
+ std::optional<int64_t> staticShape =
+ getConstantIntValue(ofr);
+ if (staticShape) {
+ return staticShape.value();
+ }
+ return ShapedType::kDynamic;
+ }));
expandedShape.append(dimExpansion.begin(), dimExpansion.end());
}
- return RankedTensorType::get(expandedShape, originalType.getElementType());
+ return {expandedShape, RankedTensorType::get(expandedStaticShape,
+ originalType.getElementType())};
}
/// Returns the reassociation maps to use in the `tensor.expand_shape`
@@ -770,49 +756,27 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter,
// Linearize the expanded indices of the original index dimension.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(indexOp);
- ArrayRef<int64_t> expandedDimsShape =
+ ArrayRef<OpFoldResult> expandedDimsShape =
expansionInfo.getExpandedShapeOfDim(indexOp.getDim()).drop_front();
SmallVector<Value> expandedIndices;
expandedIndices.reserve(expandedDims.size() - 1);
llvm::transform(
expandedDims.drop_front(), std::back_inserter(expandedIndices),
[&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); });
- Value newIndex = rewriter.create<IndexOp>(loc, expandedDims.front());
+ OpFoldResult newIndex =
+ rewriter.create<IndexOp>(loc, expandedDims.front()).getResult();
for (auto it : llvm::zip(expandedDimsShape, expandedIndices)) {
- assert(!ShapedType::isDynamic(std::get<0>(it)));
- AffineExpr idx, acc;
+ AffineExpr idx, acc, shape;
bindDims(rewriter.getContext(), idx, acc);
- newIndex = rewriter.create<affine::AffineApplyOp>(
- indexOp.getLoc(), idx + acc * std::get<0>(it),
- ValueRange{std::get<1>(it), newIndex});
- }
- rewriter.replaceOp(indexOp, newIndex);
- }
-}
-
-/// Checks if a single dynamic dimension expanded into multiple dynamic
-/// dimensions.
-static LogicalResult
-validateDynamicDimExpansion(LinalgOp linalgOp,
- const ExpansionInfo &expansionInfo,
- PatternRewriter &rewriter) {
- for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
- ArrayRef<int64_t> expandedShape = expansionInfo.getExpandedShapeOfDim(i);
- if (expandedShape.size() == 1)
- continue;
- bool foundDynamic = false;
- for (int64_t shape : expandedShape) {
- if (!ShapedType::isDynamic(shape))
- continue;
- if (foundDynamic) {
- return rewriter.notifyMatchFailure(
- linalgOp, "cannot infer expanded shape with multiple dynamic "
- "dims in the same reassociation group");
- }
- foundDynamic = true;
+ bindSymbols(rewriter.getContext(), shape);
+ newIndex = affine::makeComposedFoldedAffineApply(
+ rewriter, indexOp.getLoc(), idx + acc * shape,
+ ArrayRef<OpFoldResult>{std::get<1>(it), newIndex, std::get<0>(it)});
}
+ Value newIndexVal =
+ getValueOrCreateConstantIndexOp(rewriter, indexOp.getLoc(), newIndex);
+ rewriter.replaceOp(indexOp, newIndexVal);
}
- return success();
}
/// Implements the fusion of a tensor.collapse_shape or a tensor.expand_shape op
@@ -826,31 +790,25 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
"preconditions for fuse operation failed");
Location loc = linalgOp.getLoc();
- // Check if reshape is expanding or collapsing.
- auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(*reshapeOp);
- auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(*reshapeOp);
- bool isExpanding = (expandingReshapeOp != nullptr);
- RankedTensorType expandedType = isExpanding
- ? expandingReshapeOp.getResultType()
- : collapsingReshapeOp.getSrcType();
- RankedTensorType collapsedType = isExpanding
- ? expandingReshapeOp.getSrcType()
- : collapsingReshapeOp.getResultType();
+ SmallVector<OpFoldResult> expandedShape, collapsedShape;
+ SmallVector<AffineMap, 4> reassociationIndices;
+ Value src;
+ if (auto expandingReshapeOp = dyn_cast<tensor::ExpandShapeOp>(reshapeOp)) {
+ expandedShape = expandingReshapeOp.getMixedOutputShape();
+ reassociationIndices = expandingReshapeOp.getReassociationMaps();
+ src = expandingReshapeOp.getSrc();
+ } else {
+ auto collapsingReshapeOp = dyn_cast<tensor::CollapseShapeOp>(reshapeOp);
+ expandedShape = tensor::getMixedSizes(
+ rewriter, collapsingReshapeOp->getLoc(), collapsingReshapeOp.getSrc());
+ reassociationIndices = collapsingReshapeOp.getReassociationMaps();
+ src = collapsingReshapeOp.getSrc();
+ }
ExpansionInfo expansionInfo;
if (failed(expansionInfo.compute(
- linalgOp, fusableOpOperand,
- isExpanding ? expandingReshapeOp.getReassociationMaps()
- : collapsingReshapeOp.getReassociationMaps(),
- expandedType.getShape(), collapsedType.getShape(), rewriter)))
- return std::nullopt;
-
- // TODO: With the support of multiple dynamic dims expansion in
- // tensor.expand_shape op, this case can be handled.
- if (failed(validateDynamicDimExpansion(linalgOp, expansionInfo, rewriter)))
- return std::nullopt;
-
- if (failed(isLinalgOpExpandable(linalgOp, expansionInfo, rewriter)))
+ linalgOp, fusableOpOperand, reassociationIndices,
+ expandedShape, rewriter)))
return std::nullopt;
SmallVector<AffineMap, 4> expandedOpIndexingMaps = llvm::to_vector<4>(
@@ -866,15 +824,16 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
expandedOpOperands.reserve(linalgOp.getNumDpsInputs());
for (OpOperand *opOperand : linalgOp.getDpsInputOperands()) {
if (opOperand == fusableOpOperand) {
- expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.getSrc()
- : collapsingReshapeOp.getSrc());
+ expandedOpOperands.push_back(src);
continue;
}
if (auto opOperandType =
dyn_cast<RankedTensorType>(opOperand->get().getType())) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
- RankedTensorType expandedOperandType =
- getExpandedType(opOperandType, indexingMap, expansionInfo);
+ SmallVector<OpFoldResult> expandedOperandShape;
+ RankedTensorType expandedOperandType;
+ std::tie(expandedOperandShape, expandedOperandType) =
+ getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
if (expandedOperandType != opOperand->get().getType()) {
// Reshape the operand to get the right type.
SmallVector<ReassociationIndices> reassociation =
@@ -888,7 +847,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
/*isExpandingReshape=*/true)))
return std::nullopt;
expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
- loc, expandedOperandType, opOperand->get(), reassociation));
+ loc, expandedOperandType, opOperand->get(), reassociation,
+ expandedOperandShape));
continue;
}
}
@@ -899,8 +859,10 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
for (OpOperand &opOperand : linalgOp.getDpsInitsMutable()) {
AffineMap indexingMap = linalgOp.getMatchingIndexingMap(&opOperand);
auto opOperandType = cast<RankedTensorType>(opOperand.get().getType());
- RankedTensorType expandedOutputType =
- getExpandedType(opOperandType, indexingMap, expansionInfo);
+ SmallVector<OpFoldResult> expandedOutputShape;
+ RankedTensorType expandedOutputType;
+ std::tie(expandedOutputShape, expandedOutputType) =
+ getExpandedShapeAndType(opOperandType, indexingMap, expansionInfo);
if (expandedOutputType != opOperand.get().getType()) {
SmallVector<ReassociationIndices> reassociation =
getReassociationForExpansion(indexingMap, expansionInfo);
@@ -913,7 +875,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp,
/*isExpandingReshape=*/true)))
return std::nullopt;
outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
- loc, expandedOutputType, opOperand.get(), reassociation));
+ loc, expandedOutputType, opOperand.get(), reassociation,
+ expandedOutputShape));
} else {
outputs.push_back(opOperand.get());
}
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
index ef853e4d662a7..57904f912a35b 100644
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -30,20 +30,14 @@ func.func @generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xf32>,
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x4x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
-// CHECK: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[C3:.+]] = arith.constant 3 : index
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_1]], %[[C4]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM]], %[[DIM_0]], %[[VAL_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_3:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?x?xf32>
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_4]], %[[C4]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_2]], %[[DIM_3]], %[[VAL_1]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
+// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x4x?xf32>
+// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x4x?xf32>
+// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<?x?x4x?xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_1]], %[[DIM]], %[[DIM_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1], [2, 3]] output_shape [%[[DIM_1]], %[[DIM]], %[[DIM_0]], 4] : tensor<?x?x?xf32> into tensor<?x?x?x4xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -88,21 +82,9 @@ func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor<?x?xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index
-// CHECK: %[[C20:.+]] = arith.constant 20 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: %[[C0:.+]] = arith.constant 0 : index
-// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_0:.+]] = arith.divsi %[[DIM_0]], %[[C20]] : index
-// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM]], 4, %[[VAL_0]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
-// CHECK: %[[DIM_1:.+]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_2:.+]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_1:.+]] = arith.divsi %[[DIM_2]], %[[C20]] : index
-// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_1]], 4, %[[VAL_1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
-// CHECK: %[[DIM_4:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
-// CHECK: %[[DIM_5:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
-// CHECK: %[[VAL_2:.+]] = arith.divsi %[[DIM_5]], %[[C20]] : index
-// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[DIM_4]], 4, %[[VAL_2]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+// CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
+// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] {{\[\[}}0], [1, 2, 3]] output_shape [%[[SZ0]], 4, %[[SZ1]], 5] : tensor<?x?xf32> into tensor<?x4x?x5xf32>
// CHECK: %[[T3:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]]
// CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
@@ -137,26 +119,9 @@ func.func @reshape_as_consumer_permutation
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[SZ0:.+]]: index, %[[SZ1:.+]]: index, %[[SZ2:.+]]: index
-// CHECK: %[[C12:.+]] = arith.constant 12 : index
-// CHECK: %[[C2:.+]] = arith.constant 2 : index
-// CHECK: %[[C1:.+]] = arith.constant 1 : index
-// CHECK: ...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
09fa404
to
0e134fc
Compare
Please review the last commit. The first commit is part of #130874 |
if (failed(moveValueDefinitions( | ||
rewriter, expandingReshapeOp.getOutputShape(), linalgOp))) | ||
return std::nullopt; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might be problematic because after moveValueDefinitions
mutates the IR, the rewrite pattern is no longer allowed to return failure()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The usage of reshapeLikeShapesAreCompatible
has a similar problem below (prior to this PR). So maybe this isn't that big of a deal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, many places that happens in the IR. I dont have a good solution for this though. In general its just changing position of some tensor.dim
operations which shouldnt matter too much.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good to me, just few nits!
…dynamic shapes. With `tensor.expand_shape` allowing expanding dynamic dimension into multiple dynamic dimension, adapt the reshape propagation through expansion to handle cases where one dynamic dimension is expanded into multiple dynamic dimension. Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
0e134fc
to
dbaa97a
Compare
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
With
tensor.expand_shape
allowing expanding dynamic dimension into multiple dynamic dimension, adapt the reshape propagation through expansion to handle cases where one dynamic dimension is expanded into multiple dynamic dimension.