-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] Improve compose expand(collapse) pattern #117768
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tensor Author: Ian Wood (IanWood1) ChangesIf expand(collapse) has a dimension that gets collapsed and then expanded to the same shape, the pattern would fail to canonicalize this to a single collapse shape. Line 341 was changed because the expand(collapse) could be a reinterpret-cast like sequence where the shapes differ but the rank is the same. This cannot be represented by a single Full diff: https://github.com/llvm/llvm-project/pull/117768.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 89bc57f09ec8ba..0357e34a2e0963 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -338,7 +338,7 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
int64_t srcRank = srcType.getRank();
int64_t resultRank = resultType.getRank();
- if (srcType == resultType)
+ if (srcRank == resultRank)
return failure();
auto srcReassociation = collapseOp.getReassociationIndices();
@@ -388,12 +388,16 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
resultShape.slice(resultIndices.front(), resultIndices.size());
if (srcSubShape.size() == resultSubShape.size()) {
- if (srcSubShape == resultSubShape &&
- llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
- composedReassociation.push_back(srcIndices);
- } else {
+ if (srcSubShape != resultSubShape ||
+ llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
return std::nullopt;
}
+ for (auto dim : llvm::seq<int64_t>(0, srcSubShape.size())) {
+ ReassociationIndices reassoc;
+ reassoc.push_back(srcIndices.front() + dim);
+ composedReassociation.push_back(reassoc);
+ }
+ continue;
}
// Find reassociation to collapse `srcSubShape` into `resultSubShape`.
@@ -403,11 +407,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
return std::nullopt;
// Remap the subshape indices back to the original srcShape.
- for (auto &subshape_indices : *subShapeReassociation) {
- ReassociationIndices shape_indices;
- for (int64_t index : subshape_indices)
- shape_indices.push_back(srcIndices.front() + index);
- composedReassociation.push_back(shape_indices);
+ for (auto &subshapeIndices : *subShapeReassociation) {
+ ReassociationIndices shapeIndices;
+ for (int64_t index : subshapeIndices)
+ shapeIndices.push_back(srcIndices.front() + index);
+ composedReassociation.push_back(shapeIndices);
}
}
return {std::move(composedReassociation)};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0b54c207dea84e..613ec066337294 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1382,6 +1382,34 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
// -----
+func.func @compose_expand_of_collapse_static(%arg0 : tensor<4x32x10x64x2xf16>) -> tensor<4x32x10x128xf16> {
+ %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
+ %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 10, 128] : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>
+ return %expanded : tensor<4x32x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_static
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x32x10x64x2xf16>
+// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME: [0], [1], [2], [3, 4]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %arg1 : index) -> tensor<4x?x10x128xf16> {
+ %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x?x10x64x2xf16> into tensor<?x10x128xf16>
+ %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, %arg1, 10, 128] : tensor<?x10x128xf16> into tensor<4x?x10x128xf16>
+ return %expanded : tensor<4x?x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x?x10x64x2xf16>
+// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME: [0], [1], [2], [3, 4]
+// CHECK: return %[[RESULT]]
+
+// -----
+
// CHECK-LABEL: func @zero_rank_reshape_multi
func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: return %arg0
|
@llvm/pr-subscribers-mlir Author: Ian Wood (IanWood1) ChangesIf expand(collapse) has a dimension that gets collapsed and then expanded to the same shape, the pattern would fail to canonicalize this to a single collapse shape. Line 341 was changed because the expand(collapse) could be a reinterpret-cast like sequence where the shapes differ but the rank is the same. This cannot be represented by a single Full diff: https://github.com/llvm/llvm-project/pull/117768.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
index 89bc57f09ec8ba..0357e34a2e0963 100644
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -338,7 +338,7 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
int64_t srcRank = srcType.getRank();
int64_t resultRank = resultType.getRank();
- if (srcType == resultType)
+ if (srcRank == resultRank)
return failure();
auto srcReassociation = collapseOp.getReassociationIndices();
@@ -388,12 +388,16 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
resultShape.slice(resultIndices.front(), resultIndices.size());
if (srcSubShape.size() == resultSubShape.size()) {
- if (srcSubShape == resultSubShape &&
- llvm::count_if(srcSubShape, ShapedType::isDynamic) < 2) {
- composedReassociation.push_back(srcIndices);
- } else {
+ if (srcSubShape != resultSubShape ||
+ llvm::count_if(srcSubShape, ShapedType::isDynamic) >= 2) {
return std::nullopt;
}
+ for (auto dim : llvm::seq<int64_t>(0, srcSubShape.size())) {
+ ReassociationIndices reassoc;
+ reassoc.push_back(srcIndices.front() + dim);
+ composedReassociation.push_back(reassoc);
+ }
+ continue;
}
// Find reassociation to collapse `srcSubShape` into `resultSubShape`.
@@ -403,11 +407,11 @@ struct ComposeExpandOfCollapseOp : public OpRewritePattern<ExpandOpTy> {
return std::nullopt;
// Remap the subshape indices back to the original srcShape.
- for (auto &subshape_indices : *subShapeReassociation) {
- ReassociationIndices shape_indices;
- for (int64_t index : subshape_indices)
- shape_indices.push_back(srcIndices.front() + index);
- composedReassociation.push_back(shape_indices);
+ for (auto &subshapeIndices : *subShapeReassociation) {
+ ReassociationIndices shapeIndices;
+ for (int64_t index : subshapeIndices)
+ shapeIndices.push_back(srcIndices.front() + index);
+ composedReassociation.push_back(shapeIndices);
}
}
return {std::move(composedReassociation)};
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index 0b54c207dea84e..613ec066337294 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1382,6 +1382,34 @@ func.func @compose_expand_of_collapse_0_rank_to_collapse(%arg0 : tensor<1x1x1x1x
// -----
+func.func @compose_expand_of_collapse_static(%arg0 : tensor<4x32x10x64x2xf16>) -> tensor<4x32x10x128xf16> {
+ %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x32x10x64x2xf16> into tensor<128x10x128xf16>
+ %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, 32, 10, 128] : tensor<128x10x128xf16> into tensor<4x32x10x128xf16>
+ return %expanded : tensor<4x32x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_static
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x32x10x64x2xf16>
+// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME: [0], [1], [2], [3, 4]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func.func @compose_expand_of_collapse_dynamic(%arg0 : tensor<4x?x10x64x2xf16>, %arg1 : index) -> tensor<4x?x10x128xf16> {
+ %collapsed = tensor.collapse_shape %arg0 [[0, 1], [2], [3, 4]] : tensor<4x?x10x64x2xf16> into tensor<?x10x128xf16>
+ %expanded = tensor.expand_shape %collapsed [[0, 1], [2], [3]] output_shape [4, %arg1, 10, 128] : tensor<?x10x128xf16> into tensor<4x?x10x128xf16>
+ return %expanded : tensor<4x?x10x128xf16>
+}
+
+// CHECK-LABEL: func @compose_expand_of_collapse_dynamic
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x?x10x64x2xf16>
+// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK-SAME: [0], [1], [2], [3, 4]
+// CHECK: return %[[RESULT]]
+
+// -----
+
// CHECK-LABEL: func @zero_rank_reshape_multi
func.func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: return %arg0
|
c665b88
to
b4ff5f2
Compare
If expand(collapse) has a dimension that gets collapsed and then expanded to the same shape, the pattern would fail to canonicalize this to a single collapse shape. Signed-off-by: Ian Wood <[email protected]>
If expand(collapse) has a dimension that gets collapsed and then expanded to the same shape, the pattern would fail to canonicalize this to a single collapse shape. Line 341 was changed because the expand(collapse) could be a reinterpret-cast like sequence where the shapes differ but the rank is the same. This cannot be represented by a single
collapse_shape
op.