-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][mesh] adding option for traversal order in sharding propagation #144079
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
FYI @yaochengji |
@llvm/pr-subscribers-mlir Author: Frank Schlimbach (fschlimb) ChangesThe traversal order in sharding propagation was hard-coded. This PR provides options to the pass to select a suitable order
Default is the previous behavior (backward-forward). FYI @tkarna Full diff: https://github.com/llvm/llvm-project/pull/144079.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index 83399d10beaae..a2424d43a8ba9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -19,6 +19,18 @@ class FuncOp;
namespace mesh {
+/// This enum controls the traversal order for the sharding propagation.
+enum class TraversalOrder {
+ /// Forward traversal.
+ Forward,
+ /// Backward traversal.
+ Backward,
+ /// Forward then backward traversal.
+ ForwardBackward,
+ /// Backward then forward traversal.
+ BackwardForward
+};
+
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 06ebf151e7d64..11ec7e78cd5e6 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -24,6 +24,21 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
operation, and the operations themselves are added with sharding option
attributes.
}];
+ let options = [
+ Option<"traversal", "traversal",
+ "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
+ "Traversal order to use for sharding propagation:",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
+ "Forward only traversal."),
+ clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
+ "backward only traversal."),
+ clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
+ "forward-backward traversal."),
+ clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
+ "backward-forward traversal.")
+ )}]>,
+ ];
let dependentDialects = [
"mesh::MeshDialect"
];
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 41f208216374f..16cdbebf91900 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -18,6 +18,7 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LocationSnapshot.h"
+// #include "mlir/Transforms/ShardingPropagationUtils.h"
#include "mlir/Transforms/ViewOpGraph.h"
#include "llvm/Support/Debug.h"
#include <limits>
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 4452dd65fce9d..9d4a144912ee2 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -362,6 +362,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
//===----------------------------------------------------------------------===//
struct ShardingPropagation
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
+
+ using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
+
void runOnOperation() override {
FunctionOpInterface funcOp = getOperation();
MLIRContext *ctx = funcOp.getContext();
@@ -383,17 +386,34 @@ struct ShardingPropagation
});
// 1. propagate in reversed order
- for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
- if (failed(visitOp(&op, builder)))
- return signalPassFailure();
-
- LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
- << funcOp << "\n");
- LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+ if (traversal == TraversalOrder::Backward ||
+ traversal == TraversalOrder::BackwardForward) {
+ for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+ if (failed(visitOp(&op, builder)))
+ return signalPassFailure();
+ if (traversal == TraversalOrder::BackwardForward) {
+ LLVM_DEBUG(DBGS() << "After backward order propagation:\n"
+ << funcOp << "\n");
+ LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+ }
+ }
// 2. propagate in original order
- for (Operation &op : llvm::make_early_inc_range(block))
- if (failed(visitOp(&op, builder)))
- return signalPassFailure();
+ if (traversal != TraversalOrder::Backward) {
+ for (Operation &op : llvm::make_early_inc_range(block))
+ if (failed(visitOp(&op, builder)))
+ return signalPassFailure();
+ if (traversal == TraversalOrder::ForwardBackward) {
+ LLVM_DEBUG(DBGS() << "After forward order propagation:\n"
+ << funcOp << "\n");
+ LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+ }
+ }
+
+ // 3. propagate in backward order if needed
+ if (traversal == TraversalOrder::ForwardBackward)
+ for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+ if (failed(visitOp(&op, builder)))
+ return signalPassFailure();
}
};
diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
new file mode 100644
index 0000000000000..98e9931b8de94
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
+ mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+ func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
+ %c1_i32 = arith.constant 1 : i32
+ // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
+ %0 = tensor.empty() : tensor<6x6xi32>
+ // CHECK: [[v1:%.*]] = linalg.fill ins
+ // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
+ %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
+ %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+ %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
+ // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
+ // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
+ %3 = tensor.empty() : tensor<6x6xi32>
+ // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
+ // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+ // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
+ // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
+ %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
+ : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
+ ^bb0(%in: i32, %in_2: i32, %out: i32):
+ %9 = arith.addi %in, %in_2 : i32
+ linalg.yield %9 : i32
+ } -> tensor<6x6xi32>
+ %c0_i32 = arith.constant 0 : i32
+ %6 = tensor.empty() : tensor<i32>
+ %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
+ // CHECK: [[vreduced:%.*]] = linalg.reduce ins
+ // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial = sum [0] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
+ %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1]
+ (%in: i32, %init: i32) {
+ %9 = arith.addi %in, %init : i32
+ linalg.yield %9 : i32
+ }
+ // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
+ %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
+ %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
+ return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
+ }
+}
|
@llvm/pr-subscribers-mlir-core Author: Frank Schlimbach (fschlimb) ChangesThe traversal order in sharding propagation was hard-coded. This PR provides options to the pass to select a suitable order
Default is the previous behavior (backward-forward). FYI @tkarna Full diff: https://github.com/llvm/llvm-project/pull/144079.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index 83399d10beaae..a2424d43a8ba9 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -19,6 +19,18 @@ class FuncOp;
namespace mesh {
+/// This enum controls the traversal order for the sharding propagation.
+enum class TraversalOrder {
+ /// Forward traversal.
+ Forward,
+ /// Backward traversal.
+ Backward,
+ /// Forward then backward traversal.
+ ForwardBackward,
+ /// Backward then forward traversal.
+ BackwardForward
+};
+
//===----------------------------------------------------------------------===//
// Passes
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 06ebf151e7d64..11ec7e78cd5e6 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -24,6 +24,21 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
operation, and the operations themselves are added with sharding option
attributes.
}];
+ let options = [
+ Option<"traversal", "traversal",
+ "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
+ "Traversal order to use for sharding propagation:",
+ [{::llvm::cl::values(
+ clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
+ "Forward only traversal."),
+ clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
+ "backward only traversal."),
+ clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
+ "forward-backward traversal."),
+ clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
+ "backward-forward traversal.")
+ )}]>,
+ ];
let dependentDialects = [
"mesh::MeshDialect"
];
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
index 41f208216374f..16cdbebf91900 100644
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -18,6 +18,7 @@
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/LocationSnapshot.h"
+// #include "mlir/Transforms/ShardingPropagationUtils.h"
#include "mlir/Transforms/ViewOpGraph.h"
#include "llvm/Support/Debug.h"
#include <limits>
diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 4452dd65fce9d..9d4a144912ee2 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -362,6 +362,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
//===----------------------------------------------------------------------===//
struct ShardingPropagation
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
+
+ using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
+
void runOnOperation() override {
FunctionOpInterface funcOp = getOperation();
MLIRContext *ctx = funcOp.getContext();
@@ -383,17 +386,34 @@ struct ShardingPropagation
});
// 1. propagate in reversed order
- for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
- if (failed(visitOp(&op, builder)))
- return signalPassFailure();
-
- LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
- << funcOp << "\n");
- LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+ if (traversal == TraversalOrder::Backward ||
+ traversal == TraversalOrder::BackwardForward) {
+ for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+ if (failed(visitOp(&op, builder)))
+ return signalPassFailure();
+ if (traversal == TraversalOrder::BackwardForward) {
+ LLVM_DEBUG(DBGS() << "After backward order propagation:\n"
+ << funcOp << "\n");
+ LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+ }
+ }
// 2. propagate in original order
- for (Operation &op : llvm::make_early_inc_range(block))
- if (failed(visitOp(&op, builder)))
- return signalPassFailure();
+ if (traversal != TraversalOrder::Backward) {
+ for (Operation &op : llvm::make_early_inc_range(block))
+ if (failed(visitOp(&op, builder)))
+ return signalPassFailure();
+ if (traversal == TraversalOrder::ForwardBackward) {
+ LLVM_DEBUG(DBGS() << "After forward order propagation:\n"
+ << funcOp << "\n");
+ LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+ }
+ }
+
+ // 3. propagate in backward order if needed
+ if (traversal == TraversalOrder::ForwardBackward)
+ for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+ if (failed(visitOp(&op, builder)))
+ return signalPassFailure();
}
};
diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
new file mode 100644
index 0000000000000..98e9931b8de94
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
+ mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
+ func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
+ %c1_i32 = arith.constant 1 : i32
+ // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
+ %0 = tensor.empty() : tensor<6x6xi32>
+ // CHECK: [[v1:%.*]] = linalg.fill ins
+ // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
+ %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
+ %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
+ %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
+ // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
+ // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
+ %3 = tensor.empty() : tensor<6x6xi32>
+ // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
+ // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
+ // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
+ // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
+ %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
+ : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
+ ^bb0(%in: i32, %in_2: i32, %out: i32):
+ %9 = arith.addi %in, %in_2 : i32
+ linalg.yield %9 : i32
+ } -> tensor<6x6xi32>
+ %c0_i32 = arith.constant 0 : i32
+ %6 = tensor.empty() : tensor<i32>
+ %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
+ // CHECK: [[vreduced:%.*]] = linalg.reduce ins
+ // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial = sum [0] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
+ %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1]
+ (%in: i32, %init: i32) {
+ %9 = arith.addi %in, %init : i32
+ linalg.yield %9 : i32
+ }
+ // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
+ %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
+ // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
+ %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
+ return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
+ }
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Except for a nit and a question about test coverage, LGTM!
@@ -383,17 +386,34 @@ struct ShardingPropagation | |||
}); | |||
|
|||
// 1. propagate in reversed order |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: make into sentence, i.e. Capitalize and add full stop.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me, minor comments.
for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block))) | ||
if (failed(visitOp(&op, builder))) | ||
return signalPassFailure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpick: This loop nest is replicated three times. Could be refactored by creating a queue and, depending on the traversal order, pushing block
or llvm::reverse(block)
to it, and then iterating over the queue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The ranges have different template types, so this would be tricker than it seems.
However, I can deduplicate the loop bodies.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Deduplicated code.
llvm#144079) The traversal order in sharding propagation was hard-coded. This PR provides options to the pass to select a suitable order - forward-only - backward-only - forward-backward - backward-forward Default is the previous behavior (backward-forward).
Hi, the test added in this PR was failed in the buildbot. Can you take a look? |
…opagation" (#145531) Reverts #144079 Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140
…sharding propagation" (#145531) Reverts llvm/llvm-project#144079 Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140
…opagation" (llvm#145531) Reverts llvm#144079 Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140
Sorry for the delay, didn't see the notification about the issue or your message. Thank @qinkunbao for handling this in the meanwhile. Hope this doesn't repeat... |
…opagation" (llvm#145531) Reverts llvm#144079 Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140
llvm#144079 introduced a test with an uninitialized access Buildbot failure: https://lab.llvm.org/buildbot/#/builders/164/builds/11140 and got reverted llvm#145531 This PR is an exact copy of llvm#144079 plus a trivial fix (96c8525).
The traversal order in sharding propagation was hard-coded. This PR provides options to the pass to select a suitable order
Default is the previous behavior (backward-forward).
FYI @tkarna