[mlir] [linalg] Add pattern to swap transpose with broadcast #97063


Merged: 4 commits merged into llvm:main from cxy-fold on Jul 23, 2024

Conversation

@cxy-1993 (Contributor) commented Jun 28, 2024

Add a pattern that implements:

transpose(broadcast(input)) -> broadcast(transpose(input))
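
For illustration, a minimal before/after sketch of the rewrite (shapes and value names are illustrative, not taken from the patch):

```mlir
// Before: broadcast first, then transpose the larger broadcasted tensor.
%b = linalg.broadcast
    ins(%input : tensor<4x5xf32>)
    outs(%init0 : tensor<2x4x5xf32>)
    dimensions = [0]
%t = linalg.transpose
    ins(%b : tensor<2x4x5xf32>)
    outs(%init1 : tensor<2x5x4xf32>)
    permutation = [0, 2, 1]

// After: transpose the smaller original tensor, then broadcast.
%t2 = linalg.transpose
    ins(%input : tensor<4x5xf32>)
    outs(%init2 : tensor<5x4xf32>)
    permutation = [1, 0]
%b2 = linalg.broadcast
    ins(%t2 : tensor<5x4xf32>)
    outs(%init1 : tensor<2x5x4xf32>)
    dimensions = [0]
```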

@llvmbot (Member) commented Jun 28, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: donald chen (cxy-1993)

Changes

Add a canonicalization pattern that implements:

transpose(broadcast(input)) -> broadcast(transpose(input))

This reduces the cost of the transpose.


Full diff: https://github.com/llvm/llvm-project/pull/97063.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Utils/IndexingUtils.h (+8)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+60-1)
  • (modified) mlir/lib/Dialect/Utils/IndexingUtils.cpp (+26)
  • (modified) mlir/test/Dialect/Linalg/canonicalize.mlir (+52-1)
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
index b774359552aa5..6428409889179 100644
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -243,6 +243,14 @@ SmallVector<int64_t>
 computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
                          ArrayRef<int64_t> desiredPositions);
 
+/// Returns a permutation vector obtained by dropping the values listed in
+/// removePositions from inputPerm (renumbered to remain a valid permutation).
+///
+/// For example, inputPerm = {2, 4, 0, 1, 3} and removePositions = {1, 2} would
+/// result in a {2, 0, 1} permutation vector.
+SmallVector<int64_t> removePermutation(ArrayRef<int64_t> inputPerm,
+                                       ArrayRef<int64_t> removePositions);
+
 /// Helper to return a subset of `arrayAttr` as a vector of int64_t.
 // TODO: Port everything relevant to DenseArrayAttr and drop this util.
 SmallVector<int64_t> getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 57d126603ebd7..9e0ac5354139f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1890,9 +1890,68 @@ struct FoldTransposeWithTranspose : OpRewritePattern<linalg::TransposeOp> {
   }
 };
 
+/// This pattern reduces the cost of transpose by swapping the order of
+/// broadcast and transpose:
+///   transpose(broadcast(input)) -> broadcast(transpose(input))
+struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+                                PatternRewriter &rewriter) const override {
+    Value input = transposeOp.getInput();
+    BroadcastOp broadcastOp = input.getDefiningOp<BroadcastOp>();
+    if (!input.hasOneUse() || !broadcastOp)
+      return failure();
+
+    ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+    ArrayRef<int64_t> perms = transposeOp.getPermutation();
+
+    // Get new perms and new dimensions.
+    SmallVector<int64_t> resultPerms = removePermutation(perms, dimensions);
+    SmallVector<int64_t> resultDimensions;
+    SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
+    for (unsigned i = 0; i < dimensions.size(); i++) {
+      resultDimensions.push_back(invertPerm[dimensions[i]]);
+    }
+    llvm::sort(resultDimensions);
+
+    // Create transpose result.
+    Value broadcastInput = broadcastOp.getInput();
+    Location loc = transposeOp.getLoc();
+    MLIRContext *ctx = transposeOp.getContext();
+    SmallVector<OpFoldResult> dims;
+    auto broadcastInputTy =
+        mlir::cast<RankedTensorType>(broadcastInput.getType());
+    for (unsigned i = 0; i < broadcastInputTy.getRank(); i++) {
+      if (broadcastInputTy.isDynamicDim(i)) {
+        dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i)
+                           ->getResult(0));
+      } else {
+        dims.push_back(IntegerAttr::get(IndexType::get(ctx),
+                                        broadcastInputTy.getDimSize(i)));
+      }
+    }
+    SmallVector<OpFoldResult> transposeResultShapes =
+        applyPermutation(dims, resultPerms);
+    Value transposeInit = rewriter.create<tensor::EmptyOp>(
+        transposeOp.getLoc(), transposeResultShapes,
+        broadcastInputTy.getElementType());
+
+    // Create broadcast(transpose(input)).
+    Value transposeResult =
+        rewriter
+            .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit,
+                                 resultPerms)
+            ->getResult(0);
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        transposeOp, transposeResult, transposeOp.getInit(), resultDimensions);
+    return success();
+  }
+};
+
 void TransposeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<FoldTransposeWithTranspose>(context);
+  results.add<FoldTransposeWithTranspose, SwapTransposeWithBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
index aba225be720c3..d1822a3f1f95f 100644
--- a/mlir/lib/Dialect/Utils/IndexingUtils.cpp
+++ b/mlir/lib/Dialect/Utils/IndexingUtils.cpp
@@ -252,6 +252,32 @@ mlir::computePermutationVector(int64_t permSize, ArrayRef<int64_t> positions,
   return res;
 }
 
+SmallVector<int64_t>
+mlir::removePermutation(ArrayRef<int64_t> inputPerm,
+                        ArrayRef<int64_t> removePositions) {
+  assert(inputPerm.size() >= removePositions.size() &&
+         "expect inputPerm size large than position to remove");
+  SmallVector<int64_t> res;
+  for (unsigned inputIndex = 0; inputIndex < inputPerm.size(); inputIndex++) {
+    int64_t targetIndex = inputPerm[inputIndex];
+    bool shouldRemove = false;
+    for (unsigned removeIndex = 0; removeIndex < removePositions.size();
+         removeIndex++) {
+      if (removePositions[removeIndex] == inputPerm[inputIndex]) {
+        shouldRemove = true;
+        break;
+      }
+      if (removePositions[removeIndex] < inputPerm[inputIndex]) {
+        targetIndex--;
+      }
+    }
+    if (!shouldRemove) {
+      res.push_back(targetIndex);
+    }
+  }
+  return res;
+}
+
 SmallVector<int64_t> mlir::getI64SubArray(ArrayAttr arrayAttr,
                                           unsigned dropFront,
                                           unsigned dropBack) {
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
index 928030a81dc02..30a8d76fc73ac 100644
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
   return %0 : tensor<2x3xf32>
 }
 
-// ----
+// -----
 
 func.func @transpose_1d(%input: tensor<16xf32>,
                         %init: tensor<16xf32>) -> tensor<16xf32> {
@@ -1096,3 +1096,54 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
   func.return %transpose2 : tensor<3x4x5xf32>
 }
 
+// -----
+
+func.func @broadcast_transpose_fold(%input: tensor<2x4x5xf32>,
+                                    %init1: tensor<1x2x3x4x5x6xf32>,
+                                    %init2: tensor<1x3x2x6x5x4xf32>) -> tensor<1x3x2x6x5x4xf32> {
+  // CHECK-LABEL: @broadcast_transpose_fold
+  //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<2x4x5xf32>
+  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x2x3x4x5x6xf32>
+  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x2x6x5x4xf32>
+  //       CHECK:   %[[TMP_INIT:.+]] = tensor.empty() : tensor<2x5x4xf32>
+  //       CHECK:   %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<2x4x5xf32>) outs(%[[TMP_INIT]] : tensor<2x5x4xf32>) permutation = [0, 2, 1]
+  //       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<1x3x2x6x5x4xf32>) dimensions = [0, 1, 3]
+  //       CHECK:   return %[[BROADCAST]] : tensor<1x3x2x6x5x4xf32>
+  %broadcast = linalg.broadcast
+      ins(%input : tensor<2x4x5xf32>)
+      outs(%init1 : tensor<1x2x3x4x5x6xf32>)
+      dimensions = [0, 2, 5]
+  %transpose = linalg.transpose
+      ins(%broadcast : tensor<1x2x3x4x5x6xf32>)
+      outs(%init2 : tensor<1x3x2x6x5x4xf32>)
+      permutation = [0, 2, 1, 5, 4, 3]
+  func.return %transpose : tensor<1x3x2x6x5x4xf32>
+}
+
+// -----
+
+func.func @broadcast_transpose_fold_dynamic(%input: tensor<?x?x5xf32>,
+                                            %init1: tensor<1x?x3x?x5x6xf32>,
+                                            %init2: tensor<1x3x?x6x5x?xf32>) -> tensor<1x3x?x6x5x?xf32> {
+  // CHECK-LABEL: @broadcast_transpose_fold_dynamic
+  //  CHECK-SAME:     %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x5xf32>
+  //  CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x?x3x?x5x6xf32>
+  //  CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x?x6x5x?xf32>
+  //   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+  //   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+  //       CHECK:   %[[DIM0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x5xf32>
+  //       CHECK:   %[[DIM1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x5xf32>
+  //       CHECK:   %[[TMP_INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) : tensor<?x5x?xf32>
+  //       CHECK:   %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<?x?x5xf32>) outs(%[[TMP_INIT]] : tensor<?x5x?xf32>) permutation = [1, 2, 0]
+  //       CHECK:   %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : tensor<1x3x?x6x5x?xf32>) dimensions = [0, 1, 3]
+  //       CHECK:   return %[[BROADCAST]] : tensor<1x3x?x6x5x?xf32>
+  %broadcast = linalg.broadcast
+      ins(%input : tensor<?x?x5xf32>)
+      outs(%init1 : tensor<1x?x3x?x5x6xf32>)
+      dimensions = [0, 2, 5]
+  %transpose = linalg.transpose
+      ins(%broadcast : tensor<1x?x3x?x5x6xf32>)
+      outs(%init2 : tensor<1x3x?x6x5x?xf32>)
+      permutation = [0, 2, 3, 5, 4, 1]
+  func.return %transpose : tensor<1x3x?x6x5x?xf32>
+}

@MaheshRavishankar (Contributor) left a comment:

I understand this makes sense as something good to do always, but isn't this a one-off pattern that you want? Does it have to be a canonicalization, where it needs a fixed point? (Not hard-blocking though, because I can see why this is better in general; see the extremely long discussion here: https://discourse.llvm.org/t/rfc-update-to-general-design-section-of-operation-canonicalizations-in-mlir/79355)

@cxy-1993 (Contributor, Author) commented Jul 1, 2024

> I understand this makes sense as something good to do always, but isn't this a one-off pattern that you want? Does it have to be a canonicalization, where it needs a fixed point? (Not hard-blocking though, because I can see why this is better in general; see the extremely long discussion here: https://discourse.llvm.org/t/rfc-update-to-general-design-section-of-operation-canonicalizations-in-mlir/79355)

Thanks for the context. I completely agree with the discussion there: I also often struggle with whether an optimization pattern should be placed in fold/canonicalize or in a pass. Maybe you should document it here too (https://mlir.llvm.org/docs/Canonicalization/).

For this pattern, I think it's appropriate to use it as a canonicalization. Here are the reasons:

  • This pattern is convergent.
  • This pattern is unidirectional (in the direction of reducing data size, not like the case in the RFC).
  • This is not a one-off pattern; new matches may be generated during the application process.

@MaheshRavishankar (Contributor) left a comment:

Thanks! Fine by me to have this as a canonicalization, with the caveat that if there is a valid case to move it out of canonicalization then we do that.
Note that such a move would cause downstream breakage: something that was meant to happen automatically would become opt-in. This is another downside of canonicalization for me; downstream uses can be unstable.

@cxy-1993 (Contributor, Author) commented Jul 4, 2024

> Thanks! Fine by me to have this as a canonicalization, with the caveat that if there is a valid case to move it out of canonicalization then we do that. Note that such a move would cause downstream breakage: something that was meant to happen automatically would become opt-in. This is another downside of canonicalization for me; downstream uses can be unstable.

If this pattern breaks downstream repos, please feel free to remove it.

@dcaballe (Contributor) left a comment:

This looks like an interesting canonicalization, indeed!

Two side effects we should watch out for:

  1. This may lead to not only lit test failures but also to performance regressions if we turn a transpose with a highly optimized lowering into a simpler transpose that falls into a default lowering path.
  2. We may get some performance improvements if we get rid of the transpose in transpose(broadcast(1D->2D)) (assuming the transpose is not already optimized away somehow). We should add a test for this scenario.

An interesting follow-up would be to consider scenarios where a type conversion happens between the broadcast and the transpose.
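
A hedged sketch of the scenario in point 2 above (shapes and value names are illustrative, and this assumes the rank-1 identity transpose left behind by the swap gets folded away):

```mlir
// Before: the transpose shuffles already-duplicated (broadcasted) data.
%b = linalg.broadcast
    ins(%input : tensor<4xf32>)
    outs(%init0 : tensor<2x4xf32>)
    dimensions = [0]
%t = linalg.transpose
    ins(%b : tensor<2x4xf32>)
    outs(%init1 : tensor<4x2xf32>)
    permutation = [1, 0]

// After the swap, the transpose of the rank-1 input has permutation = [0],
// an identity that can fold away, leaving only the broadcast:
%b2 = linalg.broadcast
    ins(%input : tensor<4xf32>)
    outs(%init1 : tensor<4x2xf32>)
    dimensions = [1]
```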

/// removePositions from inputPerm.
///
/// For example, inputPerm = {2, 4, 0, 1, 3} and removePositions = {1, 2} would
/// result in a {2, 0, 1} permutation vector.
Contributor:

Is the result correct? Shouldn't this be {2, 1, 3} or am I getting this wrong?

Contributor Author:

Sorry for not making this function clear.
This function returns a new permutation after removing the values listed in removePositions from the input permutation.

The removed values here are 2 and 1; after removal we have {4, 0, 3}.
To be a valid permutation, the returned perm must be renumbered densely from 0, so the result is {2, 0, 1} (4→2, 0→0, 3→1).

Contributor:

Thanks! I think we should already have this functionality implemented on AffineMap. Would you mind taking a look at the utilities in AffineMap.h? There are some drop... methods that might get you what you need. You can get an AffineMap from a permutation with: https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/IR/AffineMap.h#L103

@cxy-1993 (Contributor, Author) commented Jul 7, 2024:

Thank you for your suggestion; this approach is very interesting (at the least, my function name is not precise enough: I think dropDims would be better). We could indeed replace the permutation computation with affine-map computations using applyPermutationMap and getPermutationMap. Before making the changes, I would like to discuss one point: this would simplify the permutation computation, but it introduces more compilation time, since we would have to construct affine maps just to reuse the affine-map utilities. Is that the more reasonable approach? If so, all of the permutation utility functions have corresponding affine-map utilities. Should we systematically replace them all?

Contributor:

Good point. My comment was motivated by the high proliferation of AffineMap and permutation utilities over the years (to the point that sometimes it's a challenge, even for people familiar with the code, to figure out whether something already exists). However, I think adding this one is justified, as it's combined with other utilities that work on permutations.

@cxy-1993 force-pushed the cxy-fold branch 2 times, most recently from d1e2439 to 28cbea9, on July 5, 2024
@cxy-1993 (Contributor, Author) commented Jul 5, 2024

> This looks like an interesting canonicalization, indeed!
>
> Two side effects we should watch out for:
>
> 1. This may lead to not only lit test failures but also to performance regressions if we turn a transpose with a highly optimized lowering into a simpler transpose that falls into a default lowering path.
> 2. We may get some performance improvements if we get rid of the transpose in transpose(broadcast(1D->2D)) (assuming the transpose is not already optimized away somehow). We should add a test for this scenario.
>
> An interesting follow-up would be to consider scenarios where a type conversion happens between the broadcast and the transpose.

Thanks for the comment. I have considered the issue you mentioned. If there is a performance improvement for some special-permutation transposes, then using expand_shape should be a better solution. Of course, if this pattern breaks downstream repos, please feel free to discuss and remove it.

@cxy-1993 cxy-1993 requested a review from dcaballe July 5, 2024 05:11
@cxy-1993 force-pushed the cxy-fold branch 2 times, most recently from bcf4997 to 7a64804, on July 7, 2024
github-actions bot commented Jul 7, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

…cast

Add a canonicalization pattern that implements:

  transpose(broadcast(input)) -> broadcast(transpose(input))

This reduces the cost of the transpose.
@stellaraccident (Contributor) commented:

> Add a canonicalization pattern that implements:
>
> transpose(broadcast(input)) -> broadcast(transpose(input))
>
> This reduces the cost of the transpose.

I'm not opposed to this being a canonicalization outright, but transpose and broadcast have no cost on their own, abstractly, and justifying the change on cost puts this in the realm of an optimization, not a canonicalization.

With that said, if we want to say that this is a canonical form, that may be an OK thing to do, but we need to define it in terms of the lattice of all such ops. As Diego says above, there is often an interplay between several ops of this category, and I'd prefer we not add a one-off pattern without a better justification and analysis of what it means for all of them.

No matter what we do, let's not land such a canonicalization with a justification that it is an optimization (i.e., which is what the PR description states as the justification). I'm marking changes requested on that point, and am happy to be overridden if the consensus is that this makes sense in terms of the lattice, versus as a piece of a specific optimization strategy.

In general, there are entire optimization pipelines that are responsible for propagating various classes of layout and data-movement operations. Such things are non-trivial and are better done as an actual algorithm that can be controlled, cost-modeled, etc.

@stellaraccident stellaraccident self-requested a review July 7, 2024 02:14
@stellaraccident (Contributor) left a comment:

As noted in my above comment, if there is a consensus that this is a canonicalization, please include that justification in the PR description, versus justifying it as an optimization. Marking changes requested for that point. It seems to me that making this judgment would be part of a bit more analysis of adjacent ops versus a one-off, but I won't stick on that point if others believe this is justified on its own.

(I do tend to agree this might be reasonable to do, but the bar for calling something a canonicalization is high, and I'd like to see it justified and documented in such a way. Some of that commentary is presented inline and needs to be moved to the PR description and agreed on.)

@cxy-1993 (Contributor, Author) commented Jul 9, 2024

> As noted in my above comment, if there is a consensus that this is a canonicalization, please include that justification in the PR description, versus justifying it as an optimization. Marking changes requested for that point. It seems to me that making this judgment would be part of a bit more analysis of adjacent ops versus a one-off, but I won't stick on that point if others believe this is justified on its own.
>
> (I do tend to agree this might be reasonable to do, but the bar for calling something a canonicalization is high, and I'd like to see it justified and documented in such a way. Some of that commentary is presented inline and needs to be moved to the PR description and agreed on.)

Thank you for your suggestion. After receiving multiple concerns about placing this pattern in canonicalize, I have reconsidered its validity as a canonicalization. For some backends, it would indeed introduce unnecessary risks, so I have moved this pattern from canonicalization to a transform. Thanks again @MaheshRavishankar @stellaraccident @dcaballe

@cxy-1993 cxy-1993 requested a review from stellaraccident July 9, 2024 13:01
@cxy-1993 (Contributor, Author) commented:

ping @MaheshRavishankar @stellaraccident :)

@MaheshRavishankar (Contributor) left a comment:

Thanks! Looks reasonable to me. I wish we had a better grouping than a "swapTransposeWithBroadcast" kind of thing, but I don't have a better idea.

I have no concerns here. I'll let @dcaballe finish his review of the implementation.

@stellaraccident (Contributor) left a comment:

LGTM

@joker-eph (Collaborator) left a comment:

> After receiving multiple concerns about placing this pattern in canonicalize, I have reconsidered its validity as a canonicalization.

Can we elaborate on this? I am concerned about a proliferation of random patterns exposed by various APIs that aren't canonicalizations when they could be: can we record these "concerns" as a rationale for why this is not a good canonicalization?

@dcaballe (Contributor) commented:

Agree with Stella that this transformation, as a canonicalization, shouldn’t be driven by cost or performance considerations. The motivating factor should be to define a canonical order between these memory/data layout operations so that we can reduce the number of op combinations to be handled in other transformations. On top of that, transposing broadcasted data is transposing duplicated data so this transformation should make the IR less redundant, which is something that falls into the canonicalization territory, IMO.

Defining a canonical order around memory/layout operations has been a long-missing gap in our representation, resulting in downstreams adopting as "canonical" whichever random order frontends or other transformations generated. This is suboptimal because these assumptions make our transformations fragile, and we become more hesitant to accept changes that might alter them.

Should this canonical representation be suitable for all the transformations? Not really. Each transformation may have a pre-processing step where the memory/layout operations can be reordered as needed. Even those pre-processing transformations would be simpler if they start from a canonical form.

IMO, it would be hard to define and integrate a canonical form for the full lattice in one shot. However, this case looks like a no-brainer to me and it’s an incremental step towards addressing the current situation.

@MaheshRavishankar (Contributor) commented:

The real issue here is that this pattern is a one-off. It doesn't fit within anything else. It's not a pattern in service of a larger transformation goal. So it appears like a zombie method.

@cxy-1993 (Contributor, Author) commented:

> > After receiving multiple concerns about placing this pattern in canonicalize, I have reconsidered its validity as a canonicalization.
>
> Can we elaborate on this? I am concerned about a proliferation of random patterns exposed by various APIs that aren't canonicalizations when they could be: can we record these "concerns" as a rationale for why this is not a good canonicalization?

I think this is a very good example to add to @MaheshRavishankar's discussion of whether an optimization should be a canonicalization. Once we have a consensus, I will update the documentation to standardize future behavior.

My previous concern was that putting this optimization in canonicalization might not provide positive benefits for all backends. However, as dcaballe and stellaraccident mentioned, the goal of canonicalization should not be to achieve maximum performance on every backend. The goal of canonicalization is to make subsequent optimizations more effective. (This should be the original definition of canonicalization, https://github.com/w3c/charmod-norm/blob/gh-pages/index.html, and we should at least agree on this point.)

According to this definition, if we define the direction of the lattice changed by canonicalization as reducing redundant data in the IR, I think this pattern is appropriate behavior for canonicalization: as I mentioned in my previous comments, this pattern is convergent and unidirectional, and it is not a one-off pattern; other patterns may generate new opportunities for it.

> The real issue here is that this pattern is a one-off. It doesn't fit within anything else. It's not a pattern in service of a larger transformation goal. So it appears like a zombie method.

For example, we can combine this with other redundancy-reducing patterns to reach:
reduce(transpose(broadcast(input))) -> broadcast(transpose(reduce(input)))
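
As a concrete, hand-worked sketch of new matches being generated by this very pattern (shapes and value names are illustrative, not from the patch): in the chain below, swapping %t past %b1 produces a new transpose of %b0 with permutation [1, 0], which is itself a fresh transpose(broadcast) match, so the pattern fires again and the remaining data movement (up to a foldable rank-1 identity transpose) sinks beneath both broadcasts.

```mlir
// transpose(broadcast(broadcast(input))): one swap exposes another.
%b0 = linalg.broadcast
    ins(%input : tensor<4xf32>)
    outs(%init0 : tensor<2x4xf32>)
    dimensions = [0]
%b1 = linalg.broadcast
    ins(%b0 : tensor<2x4xf32>)
    outs(%init1 : tensor<2x4x3xf32>)
    dimensions = [2]
%t = linalg.transpose
    ins(%b1 : tensor<2x4x3xf32>)
    outs(%init2 : tensor<3x4x2xf32>)
    permutation = [2, 1, 0]
```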

Based on the discussion above, I believe that this pattern can be considered a canonicalization, but we need to define the lattice and the direction of canonicalization carefully. Can we reach a consensus on this? @stellaraccident @MaheshRavishankar @dcaballe

@stellaraccident (Contributor) commented:

Thank you. That is the kind of justification I was looking for. Let's include that in the PR description and agree on it.

(Sorry to add overhead to the process, but this area has gotten out of hand and we are trying to do better)

@cxy-1993 (Contributor, Author) commented:

I have updated the docs according to the discussion above and https://discourse.llvm.org/t/rfc-update-to-general-design-section-of-operation-canonicalizations-in-mlir/79355. Please help me review it too. @MaheshRavishankar @stellaraccident @dcaballe @joker-eph

@stellaraccident (Contributor) commented:

Can you please update the description of this PR to include the rationale? Both you and Diego have a well-enunciated view on that, and we would like that in PR descriptions so it doesn't get lost in a centithread. Other than that, LGTM.

@stellaraccident (Contributor) commented:

> Defining a canonical order around memory/layout operations has been a long-missing gap in our representation, resulting in downstreams adopting as "canonical" whichever random order frontends or other transformations generated. This is suboptimal because these assumptions make our transformations fragile, and we become more hesitant to accept changes that might alter them.

This is well enunciated. Any chance, once this patch lands, that you could summarize it on Discourse to leave breadcrumbs for those who come next? I think it's fine to "incremental our way" to a better state, but we want to capture the thought process and plan a little more prominently. Good discussion on this thread on that front.

Comment on lines 38 to 39
canonicalization. But it is generally better to define a canonicalize
pattern that do not harm the performance.
Collaborator:

Suggested change:
- canonicalization. But it is generally better to define a canonicalize
- pattern that do not harm the performance.
+ canonicalization.

This is ill-defined.

%res = vector.shape_cast %0 : vector<nx1x<eltty>> to vector<1xnx<elty>>

is not a good canonicalize pattern because it drops the transpose semantic.
Collaborator:

This is a bit vague of an example; if we want to go there, we should find a more obvious example to explain, I think.

* Canonicalize isn't a great place to put pattens with expensive compile time
(i.e. have O(n) complexity) or complicated cost models.

* Canonicalize shouldn't drop the semantic of original operation.
Collaborator:

Suggested change:
- * Canonicalize shouldn't drop the semantic of original operation.
+ * Canonicalize shouldn't lose the semantic of the original operation: the original information should always be recoverable from the transformed IR.

@@ -33,6 +33,11 @@ together.

Some important things to think about w.r.t. canonicalization patterns:

* The goal of canonicalization is to make subsequent optimizations more
Collaborator:

Suggested change:
- * The goal of canonicalization is to make subsequent optimizations more
+ * The goal of canonicalization is to make subsequent analyses and optimizations more


is a good canonicalize pattern because:

1. This pattern is converge.
Collaborator:

Suggested change:
- 1. This pattern is converge.

This does not read well: "converge" is a verb.
Also, the notion of convergence in canonicalization, to me, isn't about a single pattern but about all the patterns together (that is, it comes back to the implicit lattice defined by the canonicalization).

2. This pattern always transforms the program towards reducing the amount of
computational data, which is a clear lattice.
3. This is not a one-off pattern, new matches may be generated during the
application process.
Collaborator:

I don't understand this sentence (nor how it applies to what makes it a good pattern)

computational data, which is a clear lattice.
3. This is not a one-off pattern, new matches may be generated during the
application process.

## Globally Applied Rules

These transformations are applied to all levels of IR:
Collaborator:

Thanks for improving the doc! Can you split this out into a separate PR, please?

@@ -51,6 +56,60 @@ Some important things to think about w.r.t. canonicalization patterns:
* It is always good to eliminate operations entirely when possible, e.g. by
folding known identities (like "x + 0 = x").

* Canonicalize isn't a great place to put pattens with expensive compile time
Member:

nit: Canonicalization. Also maybe running time is a better term because we are talking about the running time of the compiler. (Not the compile time of the C++ pattern implementation.)


is a good canonicalize pattern because:

1. This pattern is converge.
Member:

The rewrite process always converges, because there is no canonicalization pattern that performs the inverse transformation.


1. This pattern is converge.
2. This pattern always transforms the program towards reducing the amount of
computational data, which is a clear lattice.
Member:

Not quite sure what this part means: "which is a clear lattice". Maybe just drop it; "reducing the amount of computation" makes sense to me in this example.

1. This pattern is converge.
2. This pattern always transforms the program towards reducing the amount of
computational data, which is a clear lattice.
3. This is not a one-off pattern, new matches may be generated during the
Member:

What exactly does one-off mean? Does it mean that this is a kind of canonicalization that is not related to any of the other existing canonicalizations? If so, why is that a bad thing?

Contributor:

Good question. I think I've seen people use this in two ways: (a) not part of a holistic approach to reducing all forms of the implicated ops, and (b) does not benefit from the overhead of being run in a fixed-point loop (often to say it is a lowering, à la dialect conversion).

In this case, I believe the criticism was (a), but the author may have thought (b).

It's not that it isn't related to any other canonicalization, but that it isn't part of a holistic design.

I would probably split this into two:

  • it is part of a holistic, consistent design for related forms, and
  • it is not a lowering that would be better maintained in a library of one-off patterns that can be included in such pipelines.

@dcaballe (Contributor) left a comment:

The code LGTM. I would move the doc changes to a separate PR, as suggested.


SmallVector<int64_t> resultPerms = dropDims(perms, dimensions);
SmallVector<int64_t> invertPerm = invertPermutationVector(perms);
SmallVector<int64_t> resultDimensions;
for (unsigned i = 0; i < dimensions.size(); i++) {
Contributor:

nit: move the loop upper bound into a variable, use pre-increment, and drop the curly braces around single statements, per the coding guidelines.

assert(inputPerm.size() >= dropPositions.size() &&
"expect inputPerm size large than position to drop");
SmallVector<int64_t> res;
for (unsigned inputIndex = 0; inputIndex < inputPerm.size(); ++inputIndex) {
Contributor:

same

@cxy-1993 (Contributor, Author) commented:

Thanks for the comments! Based on the suggestions, the documentation part of this PR has been moved to #99753.

@cxy-1993 (Contributor, Author) commented:

> The code LGTM. I would move the doc changes to a separate PR, as suggested.

Nice catch! I have modified the code according to your comments and moved the doc changes to #99753.

@cxy-1993 cxy-1993 requested a review from dcaballe July 20, 2024 10:34
@dcaballe (Contributor) left a comment:

LGTM, thanks for bearing with us :)

@cxy-1993 cxy-1993 merged commit 9cc11b9 into llvm:main Jul 23, 2024
7 checks passed
yuxuanchen1997 pushed a commit that referenced this pull request on Jul 25, 2024:

Add a pattern that implements:

  transpose(broadcast(input)) -> broadcast(transpose(input))